- The class name of the JDBC driver needed to connect to this URL. This class with be loaded
+ The class name of the JDBC driver needed to connect to this URL. This class will be loaded
on the master and workers before running any JDBC commands to allow the driver to
register itself with the JDBC subsystem.
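
Editor's aside (not part of the patch): the contract described above is just class loading, since JDBC drivers register themselves with DriverManager in a static initializer. A minimal sketch of that mechanism; the helper name and arguments are hypothetical.

    import java.sql.{Connection, DriverManager}

    object JdbcDriverSketch {
      def connect(driverClass: String, url: String): Connection = {
        // Loading the class triggers the driver's self-registration with
        // DriverManager; only then can getConnection resolve the URL.
        Class.forName(driverClass)
        DriverManager.getConnection(url)
      }
    }
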
diff --git a/examples/pom.xml b/examples/pom.xml
index afd7c6d52f0dd..5b04b4f8d6ca0 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -245,7 +245,7 @@
       <groupId>com.twitter</groupId>
       <artifactId>algebird-core_${scala.binary.version}</artifactId>
-      <version>0.8.1</version>
+      <version>0.9.0</version>
     </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
@@ -390,11 +390,6 @@
       <artifactId>spark-streaming-kinesis-asl_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.apache.httpcomponents</groupId>
-      <artifactId>httpclient</artifactId>
-      <version>${commons.httpclient.version}</version>
-    </dependency>
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index 87d7b088f077b..2c188759328f2 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import os
+import sys
from pyspark import SparkContext
from pyspark.sql import SQLContext
@@ -50,7 +51,11 @@
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
- path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
+ if len(sys.argv) < 2:
+ path = "file://" + \
+ os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
+ else:
+ path = sys.argv[1]
# Create a DataFrame from the file(s) pointed to by path
people = sqlContext.jsonFile(path)
# root
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 921b396e799e7..9002e99d82ad3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -22,10 +22,9 @@ import scala.language.reflectiveCalls
import scopt.OptionParser
-import org.apache.spark.ml.tree.DecisionTreeModel
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.examples.mllib.AbstractParams
-import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
@@ -44,6 +43,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
* {{{
* ./bin/run-example ml.DecisionTreeExample [options]
* }}}
+ * Note that Decision Trees can take a large amount of memory. If the run-example command above
+ * fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DecisionTreeExample {
@@ -57,8 +63,6 @@ object DecisionTreeExample {
maxBins: Int = 32,
minInstancesPerNode: Int = 1,
minInfoGain: Double = 0.0,
- numTrees: Int = 1,
- featureSubsetStrategy: String = "auto",
fracTest: Double = 0.2,
cacheNodeIds: Boolean = false,
checkpointDir: Option[String] = None,
@@ -70,7 +74,7 @@ object DecisionTreeExample {
val parser = new OptionParser[Params]("DecisionTreeExample") {
head("DecisionTreeExample: an example decision tree app.")
opt[String]("algo")
- .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@@ -116,8 +120,8 @@ object DecisionTreeExample {
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.fracTest < 0 || params.fracTest > 1) {
- failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
} else {
success
}
@@ -193,9 +197,18 @@ object DecisionTreeExample {
throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
}
- val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings)
+ val training = dataframes(0).cache()
+ val test = dataframes(1).cache()
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
- (dataframes(0), dataframes(1))
+ (training, test)
}
def run(params: Params) {
@@ -210,30 +223,28 @@ object DecisionTreeExample {
val (training: DataFrame, test: DataFrame) =
loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
- val numTraining = training.count()
- val numTest = test.count()
- val numFeatures = training.select("features").first().getAs[Vector](0).size
- println("Loaded data:")
- println(s" numTraining = $numTraining, numTest = $numTest")
- println(s" numFeatures = $numFeatures")
-
// Set up Pipeline
val stages = new mutable.ArrayBuffer[PipelineStage]()
// (1) For classification, re-index classes.
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
- val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
- val featuresIndexer = new VectorIndexer().setInputCol("features")
- .setOutputCol("indexedFeatures").setMaxCategories(10)
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
stages += featuresIndexer
- // (3) Learn DecisionTree
+ // (3) Learn Decision Tree
val dt = algo match {
case "classification" =>
- new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+ new DecisionTreeClassifier()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
@@ -242,7 +253,8 @@ object DecisionTreeExample {
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
case "regression" =>
- new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+ new DecisionTreeRegressor()
+ .setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
@@ -262,62 +274,86 @@ object DecisionTreeExample {
println(s"Training time: $elapsedTime seconds")
// Get the trained Decision Tree from the fitted PipelineModel
- val treeModel: DecisionTreeModel = algo match {
+ algo match {
case "classification" =>
- pipelineModel.getModel[DecisionTreeClassificationModel](
+ val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
dt.asInstanceOf[DecisionTreeClassifier])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
case "regression" =>
- pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
- case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.")
- }
- if (treeModel.numNodes < 20) {
- println(treeModel.toDebugString) // Print full model.
- } else {
- println(treeModel) // Print model summary.
- }
-
- // Predict on training
- val trainingFullPredictions = pipelineModel.transform(training).cache()
- val trainingPredictions = trainingFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
- // Predict on test data
- val testFullPredictions = pipelineModel.transform(test).cache()
- val testPredictions = testFullPredictions.select("prediction")
- .map(_.getDouble(0))
- val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
-
- // For classification, print number of classes for reference.
- if (algo == "classification") {
- val numClasses =
- MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
- case Some(n) => n
- case None => throw new RuntimeException(
- "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
+ dt.asInstanceOf[DecisionTreeRegressor])
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
}
- println(s"numClasses = $numClasses.")
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
// Evaluate model on training, test data
algo match {
case "classification" =>
- val trainingAccuracy =
- new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
- println(s"Train accuracy = $trainingAccuracy")
- val testAccuracy =
- new MulticlassMetrics(testPredictions.zip(testLabels)).precision
- println(s"Test accuracy = $testAccuracy")
+ println("Training data results:")
+ evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateClassificationModel(pipelineModel, test, labelColName)
case "regression" =>
- val trainingRMSE =
- new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
- println(s"Training root mean squared error (RMSE) = $trainingRMSE")
- val testRMSE =
- new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
- println(s"Test root mean squared error (RMSE) = $testRMSE")
+ println("Training data results:")
+ evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ evaluateRegressionModel(pipelineModel, test, labelColName)
case _ =>
throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
}
sc.stop()
}
+
+ /**
+ * Evaluate the given ClassificationModel on data. Print the results.
+ * @param model Model that conforms to the ClassificationModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateClassificationModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ // Print number of classes for reference
+ val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "Unknown failure when indexing labels for classification.")
+ }
+ val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
+ println(s" Accuracy ($numClasses classes): $accuracy")
+ }
+
+ /**
+ * Evaluate the given RegressionModel on data. Print the results.
+ * @param model Model that conforms to the RegressionModel abstraction
+ * @param data DataFrame with "prediction" and labelColName columns
+ * @param labelColName Name of the labelCol parameter for the model
+ *
+ * TODO: Change model type to RegressionModel once that API is public. SPARK-5995
+ */
+ private[ml] def evaluateRegressionModel(
+ model: Transformer,
+ data: DataFrame,
+ labelColName: String): Unit = {
+ val fullPredictions = model.transform(data).cache()
+ val predictions = fullPredictions.select("prediction").map(_.getDouble(0))
+ val labels = fullPredictions.select(labelColName).map(_.getDouble(0))
+ val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError
+ println(s" Root mean squared error (RMSE): $RMSE")
+ }
}
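
Editor's aside on the new evaluateClassificationModel helper above: MulticlassMetrics.precision with no argument is the overall fraction of correct predictions, which is why the helper prints it as "Accuracy". A self-contained sketch (not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.evaluation.MulticlassMetrics

    object AccuracySketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("AccuracySketch").setMaster("local"))
        // (prediction, label) pairs; two of three predictions are correct.
        val predictionAndLabels = sc.parallelize(Seq((1.0, 1.0), (0.0, 1.0), (0.0, 0.0)))
        println(new MulticlassMetrics(predictionAndLabels).precision) // prints 0.666...
        sc.stop()
      }
    }
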
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
new file mode 100644
index 0000000000000..5fccb142d4c3d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for Gradient-Boosted Trees (GBTs). Run with
+ * {{{
+ * ./bin/run-example ml.GBTExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object GBTExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ maxIter: Int = 10,
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GBTExample") {
+ head("GBTExample: an example Gradient-Boosted Trees app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("maxIter")
+ .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"GBTExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"GBTExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn GBT
+ val dt = algo match {
+ case "classification" =>
+ new GBTClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case "regression" =>
+ new GBTRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setMaxIter(params.maxIter)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained GBT from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val gbtModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case "regression" =>
+ val gbtModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ if (gbtModel.totalNumNodes < 30) {
+ println(gbtModel.toDebugString) // Print full model.
+ } else {
+ println(gbtModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
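
Editor's aside: the parser in GBTExample above repeats the scopt structure of DecisionTreeExample. A trimmed, self-contained sketch of that shared pattern, not part of the patch; names here are hypothetical:

    import scopt.OptionParser

    object ParserSketch {
      // Each option copies one field of an immutable Params case class.
      case class Params(algo: String = "classification", fracTest: Double = 0.2)

      def main(args: Array[String]): Unit = {
        val parser = new OptionParser[Params]("ParserSketch") {
          head("ParserSketch: the option-parsing pattern used by the ml examples.")
          opt[String]("algo")
            .text("algorithm (classification, regression)")
            .action((x, c) => c.copy(algo = x))
          opt[Double]("fracTest")
            .text("fraction of data to hold out for testing, in [0,1)")
            .action((x, c) => c.copy(fracTest = x))
          // checkConfig validates the fully assembled Params.
          checkConfig { params =>
            if (params.fracTest < 0 || params.fracTest >= 1) {
              failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
            } else {
              success
            }
          }
        }
        parser.parse(args, Params()).foreach(println)
      }
    }
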
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
new file mode 100644
index 0000000000000..9b909324ec82a
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
+import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * An example runner for random forests. Run with
+ * {{{
+ * ./bin/run-example ml.RandomForestExample [options]
+ * }}}
+ * Decision Trees and ensembles can take a large amount of memory. If the run-example command
+ * above fails, try running via spark-submit and specifying the amount of memory as at least 1g.
+ * For local mode, run
+ * {{{
+ * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g
+ * [examples JAR path] [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object RandomForestExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 10,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("RandomForestExample") {
+ head("RandomForestExample: an example random forest app.")
+ opt[String]("algo")
+ .text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Int]("numTrees")
+ .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(numTrees = x))
+ opt[String]("featureSubsetStrategy")
+ .text(s"number of features to use per node (supported:" +
+ s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," +
+ s" default: ${defaultParams.numTrees}")
+ .action((x, c) => c.copy(featureSubsetStrategy = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${
+ defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }
+ }")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest >= 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"RandomForestExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"RandomForestExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, algo, params.fracTest)
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("labelString")
+ .setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer()
+ .setInputCol("features")
+ .setOutputCol("indexedFeatures")
+ .setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn Random Forest
+ val dt = algo match {
+ case "classification" =>
+ new RandomForestClassifier()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case "regression" =>
+ new RandomForestRegressor()
+ .setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ .setFeatureSubsetStrategy(params.featureSubsetStrategy)
+ .setNumTrees(params.numTrees)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Random Forest from the fitted PipelineModel
+ algo match {
+ case "classification" =>
+ val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
+ dt.asInstanceOf[RandomForestClassifier])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case "regression" =>
+ val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
+ dt.asInstanceOf[RandomForestRegressor])
+ if (rfModel.totalNumNodes < 30) {
+ println(rfModel.toDebugString) // Print full model.
+ } else {
+ println(rfModel) // Print model summary.
+ }
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName)
+ case "regression" =>
+ println("Training data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName)
+ println("Test data results:")
+ DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName)
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
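
Editor's aside: all three example apps now reject fracTest outside [0,1). A self-contained sketch (not part of the patch) of the randomSplit convention behind the open upper bound: the weights are (1 - fracTest, fracTest), so fracTest == 1 would leave an empty training set.

    import org.apache.spark.{SparkConf, SparkContext}

    object FracTestSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("FracTestSketch").setMaster("local"))
        val data = sc.parallelize(1 to 100)
        val fracTest = 0.2
        // Roughly 80/20 split; fracTest == 1 would put everything in `test`.
        val Array(training, test) = data.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345)
        println(s"numTraining = ${training.count()}, numTest = ${test.count()}")
        sc.stop()
      }
    }
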
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
index 431ead8c0c165..0763a7736305a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
import org.apache.spark.util.Utils
+
/**
* An example runner for Gradient Boosting using decision trees as weak learners. Run with
* {{{
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
index f40caad322f59..85b9a54b40baf 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
@@ -56,7 +56,7 @@ object MQTTPublisher {
while (true) {
try {
msgtopic.publish(message)
- println(s"Published data. topic: {msgtopic.getName()}; Message: {message}")
+ println(s"Published data. topic: ${msgtopic.getName()}; Message: $message")
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
Thread.sleep(10)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala
index 62f49530edb12..c10de84a80ffe 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala
@@ -18,6 +18,7 @@
package org.apache.spark.examples.streaming
import com.twitter.algebird._
+import com.twitter.algebird.CMSHasherImplicits._
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext._
@@ -67,7 +68,8 @@ object TwitterAlgebirdCMS {
val users = stream.map(status => status.getUser.getId)
- val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC)
+ // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC)
+ val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC)
var globalCMS = cms.zero
val mm = new MapMonoid[Long, Int]()
var globalExact = Map[Long, Int]()
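
For context (editor's sketch, not in the patch): the Algebird 0.9.0 migration above replaces the removed CountMinSketchMonoid with TopPctCMS.monoid, and the CMSHasherImplicits import supplies the CMSHasher[Long] evidence the new factory needs. Parameter values below are hypothetical:

    import com.twitter.algebird.TopPctCMS
    import com.twitter.algebird.CMSHasherImplicits._

    object CmsSketch {
      def main(args: Array[String]): Unit = {
        val (eps, delta, seed, pct) = (0.001, 1e-10, 1, 0.001)
        val cms = TopPctCMS.monoid[Long](eps, delta, seed, pct)
        // Build per-item sketches and merge them, as the streaming example does.
        val sketch = Seq(1L, 2L, 2L, 3L).map(cms.create).reduce(_ ++ _)
        println(sketch.heavyHitters)
      }
    }
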
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
index 4d26b640e8d74..cca0fac0234e1 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.util.Utils
+import org.apache.spark.util.ThreadUtils
/**
* Input stream that pulls messages from a Kafka Broker.
@@ -111,7 +111,8 @@ class KafkaReceiver[
val topicMessageStreams = consumerConnector.createMessageStreams(
topics, keyDecoder, valueDecoder)
- val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler")
+ val executorPool =
+ ThreadUtils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler")
try {
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
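
Editor's aside: ThreadUtils.newDaemonFixedThreadPool is a private[spark] helper, but the construction is plain java.util.concurrent. A standalone sketch (not part of the patch) of the equivalent pool; daemon threads do not keep the JVM alive, so they cannot block receiver shutdown:

    import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

    object DaemonPoolSketch {
      def main(args: Array[String]): Unit = {
        val factory = new ThreadFactory {
          override def newThread(r: Runnable): Thread = {
            val t = new Thread(r, "KafkaMessageHandler")
            t.setDaemon(true) // daemon: does not prevent JVM exit
            t
          }
        }
        val pool = Executors.newFixedThreadPool(4, factory)
        pool.submit(new Runnable { override def run(): Unit = println("handled") })
        pool.shutdown()
        pool.awaitTermination(5, TimeUnit.SECONDS)
      }
    }
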
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
index c4a44c1822c39..ea87e960379f1 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
@@ -33,7 +33,7 @@ import org.I0Itec.zkclient.ZkClient
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.ThreadUtils
/**
* ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss.
@@ -121,7 +121,7 @@ class ReliableKafkaReceiver[
zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs,
consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer)
- messageHandlerThreadPool = Utils.newDaemonFixedThreadPool(
+ messageHandlerThreadPool = ThreadUtils.newDaemonFixedThreadPool(
topics.values.sum, "KafkaMessageHandler")
blockGenerator.start()
diff --git a/launcher/pom.xml b/launcher/pom.xml
index 182e5f60218db..ebfa7685eaa18 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -68,6 +68,12 @@
org.apache.hadoophadoop-clienttest
+
+
+ org.codehaus.jackson
+ jackson-mapper-asl
+
+
diff --git a/make-distribution.sh b/make-distribution.sh
index 738a9c4d69601..cb65932b4abc0 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)"
DISTDIR="$SPARK_HOME/dist"
SPARK_TACHYON=false
-TACHYON_VERSION="0.5.0"
+TACHYON_VERSION="0.6.4"
TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz"
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index cae5082b51196..a491bc7ee8295 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap
abstract class Model[M <: Model[M]] extends Transformer {
/**
* The parent estimator that produced this model.
+ * Note: For ensembles' component Models, this value can be null.
*/
val parent: Estimator[M]
/**
* Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+ * Note: For ensembles' component Models, this value can be null.
*/
val fittingParamMap: ParamMap
}
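
Editor's aside on the nullable-parent caveat documented above: the ensemble conversion code later in this patch constructs component trees with parent = null, so callers should guard the field. A hypothetical helper, not part of the patch:

    object ParentSketch {
      // Wrap the possibly-null parent in Option before dereferencing it.
      def describeParent(parent: AnyRef): String =
        Option(parent).map(_.getClass.getSimpleName).getOrElse("(none: ensemble component)")
    }
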
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 7fb87fe452ee6..0acda71ec6045 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -94,7 +94,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
}
val outputFields = schema.fields :+
- StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
+ StructField(map(outputCol), outputDataType, nullable = false)
StructType(outputFields)
}
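
Editor's aside: the change above hard-codes nullable = false because a UnaryTransformer computes a value for every input row, whatever the output type. A small schema sketch (not part of the patch):

    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

    object SchemaSketch {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(StructField("text", StringType)))
        // The appended output column is always populated, hence non-nullable.
        val out = StructType(schema.fields :+ StructField("features", DoubleType, nullable = false))
        out.printTreeString()
      }
    }
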
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3855e396b5534..ee2a8dc6db171 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame
@AlphaComponent
final class DecisionTreeClassifier
extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
- with DecisionTreeParams
- with TreeClassifierParams {
+ with DecisionTreeParams with TreeClassifierParams {
// Override parameter setters from parent trait for Java API compatibility.
@@ -59,11 +58,9 @@ final class DecisionTreeClassifier
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
- override def setCacheNodeIds(value: Boolean): this.type =
- super.setCacheNodeIds(value)
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
- override def setCheckpointInterval(value: Int): this.type =
- super.setCheckpointInterval(value)
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
override def setImpurity(value: String): this.type = super.setImpurity(value)
@@ -75,8 +72,9 @@ final class DecisionTreeClassifier
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
case Some(n: Int) => n
case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
- s" with invalid label column, without the number of classes specified.")
- // TODO: Automatically index labels.
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
@@ -85,18 +83,16 @@ final class DecisionTreeClassifier
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
- override private[ml] def getOldStrategy(
+ private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
numClasses: Int): OldStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
- strategy.algo = OldAlgo.Classification
- strategy.setImpurity(getOldImpurity)
- strategy
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
+ subsamplingRate = 1.0)
}
}
object DecisionTreeClassifier {
- /** Accessor for supported impurities */
+ /** Accessor for supported impurities: entropy, gini */
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
new file mode 100644
index 0000000000000..d2e052fbbbf22
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Param, Params, ParamMap}
+import org.apache.spark.ml.regression.DecisionTreeRegressionModel
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ */
+@AlphaComponent
+final class GBTClassifier
+ extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+ with GBTParams with TreeClassifierParams with Logging {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance."
+ */
+ override def setImpurity(value: String): this.type = {
+ logWarning("GBTClassifier.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = {
+ logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+ super.setSeed(value)
+ }
+
+ // Parameters from GBTParams:
+
+ override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+ override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+ // Parameters for GBTClassifier:
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+
+ setDefault(lossType -> "logistic")
+
+ /** @group setParam */
+ def setLossType(value: String): this.type = {
+ val lossStr = value.toLowerCase
+ require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
+ s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+ set(lossType, lossStr)
+ this
+ }
+
+ /** @group getParam */
+ def getLossType: String = getOrDefault(lossType)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): GBTClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("GBTClassifier was given input" +
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
+ }
+ require(numClasses == 2,
+ s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val oldGBT = new OldGBT(boostingStrategy)
+ val oldModel = oldGBT.run(oldDataset)
+ GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object GBTClassifier {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for classification.
+ * It supports binary labels, as well as both continuous and categorical features.
+ * Note: Multiclass labels are not currently supported.
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+@AlphaComponent
+final class GBTClassificationModel(
+ override val parent: GBTClassifier,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double])
+ extends PredictionModel[Vector, GBTClassificationModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
+ s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model: SPARK-7127
+ // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+ // Classifies by thresholding sum of weighted tree predictions
+ val treePredictions = _trees.map(_.rootNode.predict(features))
+ val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ if (prediction > 0.0) 1.0 else 0.0
+ }
+
+ override protected def copy(): GBTClassificationModel = {
+ val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"GBTClassificationModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
+ }
+}
+
+private[ml] object GBTClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ }
+}
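
Editor's aside on GBTClassificationModel.predict above: the decision rule is a BLAS dot product of raw per-tree predictions and tree weights, thresholded at zero. A worked sketch (not part of the patch) with hypothetical values:

    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    object GbtVoteSketch {
      def main(args: Array[String]): Unit = {
        val treePredictions = Array(0.8, -0.2, 0.5) // hypothetical raw tree outputs
        val treeWeights = Array(1.0, 0.6, 0.6)
        // margin = 0.8*1.0 + (-0.2)*0.6 + 0.5*0.6 = 0.98
        val margin = blas.ddot(treePredictions.length, treePredictions, 1, treeWeights, 1)
        println(if (margin > 0.0) 1.0 else 0.0) // prints 1.0
      }
    }
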
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
new file mode 100644
index 0000000000000..cfd6508fce890
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for
+ * classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class RandomForestClassifier
+ extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+ with RandomForestParams with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ // Parameters from TreeClassifierParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
+ // Parameters from RandomForestParams:
+
+ override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ super.setFeatureSubsetStrategy(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): RandomForestClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("RandomForestClassifier was given input" +
+ s" with invalid label column ${paramMap(labelCol)}, without the number of classes" +
+ " specified. See StringIndexer.")
+ // TODO: Automatically index labels: SPARK-7126
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
+ val oldModel = OldRandomForest.trainClassifier(
+ oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
+ RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object RandomForestClassifier {
+ /** Accessor for supported impurity settings: entropy, gini */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ RandomForestParams.supportedFeatureSubsetStrategies
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ * @param _trees Decision trees in the ensemble.
+ * Warning: These have null parents.
+ */
+@AlphaComponent
+final class RandomForestClassificationModel private[ml] (
+ override val parent: RandomForestClassifier,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeClassificationModel])
+ extends PredictionModel[Vector, RandomForestClassificationModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
+ // Classifies using majority votes.
+ // Ignore the weights since all are 1.0 for now.
+ val votes = mutable.Map.empty[Int, Double]
+ _trees.view.foreach { tree =>
+ val prediction = tree.rootNode.predict(features).toInt
+ votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+ }
+ votes.maxBy(_._2)._1
+ }
+
+ override protected def copy(): RandomForestClassificationModel = {
+ val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"RandomForestClassificationModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
+ }
+}
+
+private[ml] object RandomForestClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new RandomForestClassificationModel(parent, fittingParamMap, newTrees)
+ }
+}
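
For reference, a minimal usage sketch of the new classifier: it assumes a DataFrame `training` (and `test`) with "label" and "features" columns, where the label column carries the number-of-classes metadata required by train() (e.g. produced by StringIndexer).

{{{
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
  .setNumTrees(30)
  .setFeatureSubsetStrategy("sqrt")
  .setMaxDepth(5)
val model = rf.fit(training)            // RandomForestClassificationModel
val predicted = model.transform(test)   // appends a "prediction" column
}}}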
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
new file mode 100644
index 0000000000000..e6a62d998bb97
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Params for [[IDF]] and [[IDFModel]].
+ */
+private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {
+
+ /**
+   * The minimum number of documents in which a term should appear.
+ * @group param
+ */
+ final val minDocFreq = new IntParam(
+ this, "minDocFreq", "minimum of documents in which a term should appear for filtering")
+
+ setDefault(minDocFreq -> 0)
+
+ /** @group getParam */
+ def getMinDocFreq: Int = getOrDefault(minDocFreq)
+
+ /** @group setParam */
+ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
+
+ /**
+ * Validate and transform the input schema.
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Compute the Inverse Document Frequency (IDF) given a collection of documents.
+ */
+@AlphaComponent
+final class IDF extends Estimator[IDFModel] with IDFBase {
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
+ val idf = new feature.IDF(map(minDocFreq)).fit(input)
+ val model = new IDFModel(this, map, idf)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[IDF]].
+ */
+@AlphaComponent
+class IDFModel private[ml] (
+ override val parent: IDF,
+ override val fittingParamMap: ParamMap,
+ idfModel: feature.IDFModel)
+ extends Model[IDFModel] with IDFBase {
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val idf = udf { vec: Vector => idfModel.transform(vec) }
+ dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
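
A minimal sketch of the fit/transform flow for the new IDF estimator, assuming a DataFrame `df` with a term-frequency Vector column named "tf" (for instance the output of a hashing/term-frequency step):

{{{
import org.apache.spark.ml.feature.IDF

val idf = new IDF()
  .setInputCol("tf")
  .setOutputCol("tfidf")
  .setMinDocFreq(2)      // minimum number of documents a term must appear in
val idfModel = idf.fit(df)              // wraps mllib.feature.IDFModel
val rescaled = idfModel.transform(df)   // adds the "tfidf" vector column
}}}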
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
new file mode 100644
index 0000000000000..d855f04799ae7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.types.DataType
+
+/**
+ * :: AlphaComponent ::
+ * Perform feature expansion in a polynomial space. As the
+ * [[http://en.wikipedia.org/wiki/Polynomial_expansion Wikipedia article on polynomial expansion]]
+ * puts it, "In mathematics, an expansion of a product of sums expresses it as a sum of products by
+ * using the fact that multiplication distributes over addition". For example, expanding the
+ * 2-variable feature vector `(x, y)` with degree 2 yields `(x, y, x * x, x * y, y * y)`.
+ */
+@AlphaComponent
+class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
+
+ /**
+ * The polynomial degree to expand, which should be larger than 1.
+ * @group param
+ */
+ val degree = new IntParam(this, "degree", "the polynomial degree to expand")
+ setDefault(degree -> 2)
+
+ /** @group getParam */
+ def getDegree: Int = getOrDefault(degree)
+
+ /** @group setParam */
+ def setDegree(value: Int): this.type = set(degree, value)
+
+ override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v =>
+ val d = paramMap(degree)
+ PolynomialExpansion.expand(v, d)
+ }
+
+ override protected def outputDataType: DataType = new VectorUDT()
+}
+
+/**
+ * The expansion is done via recursion. Given n features and degree d, the size after expansion is
+ * (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the
+ * function that expands [a, b, c] to their monomials of degree 3. We have the following recursion:
+ *
+ * {{{
+ * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3]
+ * }}}
+ *
+ * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
+ * current index and increment it properly for sparse input.
+ */
+object PolynomialExpansion {
+
+ private def choose(n: Int, k: Int): Int = {
+ Range(n, n - k, -1).product / Range(k, 1, -1).product
+ }
+
+ private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
+
+ private def expandDense(
+ values: Array[Double],
+ lastIdx: Int,
+ degree: Int,
+ multiplier: Double,
+ polyValues: Array[Double],
+ curPolyIdx: Int): Int = {
+ if (multiplier == 0.0) {
+ // do nothing
+ } else if (degree == 0 || lastIdx < 0) {
+ if (curPolyIdx >= 0) { // skip the very first 1
+ polyValues(curPolyIdx) = multiplier
+ }
+ } else {
+ val v = values(lastIdx)
+ val lastIdx1 = lastIdx - 1
+ var alpha = multiplier
+ var i = 0
+ var curStart = curPolyIdx
+ while (i <= degree && alpha != 0.0) {
+ curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart)
+ i += 1
+ alpha *= v
+ }
+ }
+ curPolyIdx + getPolySize(lastIdx + 1, degree)
+ }
+
+ private def expandSparse(
+ indices: Array[Int],
+ values: Array[Double],
+ lastIdx: Int,
+ lastFeatureIdx: Int,
+ degree: Int,
+ multiplier: Double,
+ polyIndices: mutable.ArrayBuilder[Int],
+ polyValues: mutable.ArrayBuilder[Double],
+ curPolyIdx: Int): Int = {
+ if (multiplier == 0.0) {
+ // do nothing
+ } else if (degree == 0 || lastIdx < 0) {
+ if (curPolyIdx >= 0) { // skip the very first 1
+ polyIndices += curPolyIdx
+ polyValues += multiplier
+ }
+ } else {
+ // Skip all zeros at the tail.
+ val v = values(lastIdx)
+ val lastIdx1 = lastIdx - 1
+ val lastFeatureIdx1 = indices(lastIdx) - 1
+ var alpha = multiplier
+ var curStart = curPolyIdx
+ var i = 0
+ while (i <= degree && alpha != 0.0) {
+ curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha,
+ polyIndices, polyValues, curStart)
+ i += 1
+ alpha *= v
+ }
+ }
+ curPolyIdx + getPolySize(lastFeatureIdx + 1, degree)
+ }
+
+ private def expand(dv: DenseVector, degree: Int): DenseVector = {
+ val n = dv.size
+ val polySize = getPolySize(n, degree)
+ val polyValues = new Array[Double](polySize - 1)
+ expandDense(dv.values, n - 1, degree, 1.0, polyValues, -1)
+ new DenseVector(polyValues)
+ }
+
+ private def expand(sv: SparseVector, degree: Int): SparseVector = {
+ val polySize = getPolySize(sv.size, degree)
+ val nnz = sv.values.length
+ val nnzPolySize = getPolySize(nnz, degree)
+ val polyIndices = mutable.ArrayBuilder.make[Int]
+ polyIndices.sizeHint(nnzPolySize - 1)
+ val polyValues = mutable.ArrayBuilder.make[Double]
+ polyValues.sizeHint(nnzPolySize - 1)
+ expandSparse(
+ sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, -1)
+ new SparseVector(polySize - 1, polyIndices.result(), polyValues.result())
+ }
+
+ def expand(v: Vector, degree: Int): Vector = {
+ v match {
+ case dv: DenseVector => expand(dv, degree)
+ case sv: SparseVector => expand(sv, degree)
+ case _ => throw new IllegalArgumentException
+ }
+ }
+}
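
As a concrete size check of the recursion above: for a 3-dimensional input and degree 2, the expanded vector has C(3 + 2, 2) - 1 = 9 entries (the leading constant 1 is skipped, as in expandDense/expandSparse). A minimal usage sketch, assuming a DataFrame `df` with a Vector column named "features":

{{{
import org.apache.spark.ml.feature.PolynomialExpansion

val polyExpander = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("polyFeatures")
  .setDegree(2)
// A 3-dimensional input row expands to a 9-dimensional "polyFeatures" vector.
val expanded = polyExpander.transform(df)
}}}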
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index e567e069e7c0b..7b2a451ca5ee5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -55,7 +55,8 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
schema(c).dataType match {
case DoubleType => UnresolvedAttribute(c)
case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
- case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+ case _: NumericType | BooleanType =>
+ Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
}
}
dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
@@ -67,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
val outputColName = map(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
inputDataTypes.foreach {
- case _: NativeType =>
+ case _: NumericType | BooleanType =>
case t if t.isInstanceOf[VectorUDT] =>
case other =>
throw new IllegalArgumentException(s"Data type $other is not supported.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index 6f4509f03d033..ab6281b9b2e34 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.impl.estimator.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo,
+ BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
/**
@@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
- this.asInstanceOf[this.type]
}
/** @group getParam */
- def getMaxDepth: Int = getOrDefault(maxDepth)
+ final def getMaxDepth: Int = getOrDefault(maxDepth)
/** @group setParam */
def setMaxBins(value: Int): this.type = {
require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
set(maxBins, value)
- this
}
/** @group getParam */
- def getMaxBins: Int = getOrDefault(maxBins)
+ final def getMaxBins: Int = getOrDefault(maxBins)
/** @group setParam */
def setMinInstancesPerNode(value: Int): this.type = {
require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
set(minInstancesPerNode, value)
- this
}
/** @group getParam */
- def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+ final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
/** @group setParam */
- def setMinInfoGain(value: Double): this.type = {
- set(minInfoGain, value)
- this
- }
+ def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** @group getParam */
- def getMinInfoGain: Double = getOrDefault(minInfoGain)
+ final def getMinInfoGain: Double = getOrDefault(minInfoGain)
/** @group expertSetParam */
def setMaxMemoryInMB(value: Int): this.type = {
require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
set(maxMemoryInMB, value)
- this
}
/** @group expertGetParam */
- def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+ final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
/** @group expertSetParam */
- def setCacheNodeIds(value: Boolean): this.type = {
- set(cacheNodeIds, value)
- this
- }
+ def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/** @group expertGetParam */
- def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+ final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
/** @group expertSetParam */
def setCheckpointInterval(value: Int): this.type = {
require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
set(checkpointInterval, value)
- this
}
/** @group expertGetParam */
- def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+ final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
- /**
- * Create a Strategy instance to use with the old API.
- * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0,
- * the default for single trees).
- */
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
- numClasses: Int): OldStrategy = {
- val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity,
+ subsamplingRate: Double): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(oldAlgo)
+ strategy.impurity = oldImpurity
strategy.checkpointInterval = getCheckpointInterval
strategy.maxBins = getMaxBins
strategy.maxDepth = getMaxDepth
@@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
- strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy.subsamplingRate = subsamplingRate
strategy
}
}
/**
- * (private trait) Parameters for Decision Tree-based classification algorithms.
+ * Parameters for Decision Tree-based classification algorithms.
*/
private[ml] trait TreeClassifierParams extends Params {
@@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params {
* (default = gini)
* @group param
*/
- val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
@@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params {
s"Tree-based classifier was given unrecognized impurity: $value." +
s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
set(impurity, impurityStr)
- this
}
/** @group getParam */
- def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
private[ml] def getOldImpurity: OldImpurity = {
@@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params {
private[ml] object TreeClassifierParams {
// These options should be lowercase.
- val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+ final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
}
/**
- * (private trait) Parameters for Decision Tree-based regression algorithms.
+ * Parameters for Decision Tree-based regression algorithms.
*/
private[ml] trait TreeRegressorParams extends Params {
@@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params {
* (default = variance)
* @group param
*/
- val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
" information gain calculation (case-insensitive). Supported options:" +
s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
@@ -276,14 +267,13 @@ private[ml] trait TreeRegressorParams extends Params {
s"Tree-based regressor was given unrecognized impurity: $value." +
s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
set(impurity, impurityStr)
- this
}
/** @group getParam */
- def getImpurity: String = getOrDefault(impurity)
+ final def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
- protected def getOldImpurity: OldImpurity = {
+ private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>
@@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params {
private[ml] object TreeRegressorParams {
// These options should be lowercase.
- val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+ final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based ensemble algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+
+ /**
+ * Fraction of the training data used for learning each decision tree.
+ * (default = 1.0)
+ * @group param
+ */
+ final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+ "Fraction of the training data used for learning each decision tree.")
+
+ setDefault(subsamplingRate -> 1.0)
+
+ /** @group setParam */
+ def setSubsamplingRate(value: Double): this.type = {
+ require(value > 0.0 && value <= 1.0,
+ s"Subsampling rate must be in range (0,1]. Bad rate: $value")
+ set(subsamplingRate, value)
+ }
+
+ /** @group getParam */
+ final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity and seed.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int,
+ oldAlgo: OldAlgo.Algo,
+ oldImpurity: OldImpurity): OldStrategy = {
+ super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Random Forest algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
+
+ /**
+ * The number of features to consider for splits at each tree node.
+ * Supported options:
+ * - "auto": Choose automatically for task:
+ * If numTrees == 1, set to "all."
+ * If numTrees > 1 (forest), set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * - "all": use all features
+ * - "onethird": use 1/3 of the features
+ * - "sqrt": use sqrt(number of features)
+ * - "log2": use log2(number of features)
+ * (default = "auto")
+ *
+ * These various settings are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
+ * random forests]]
+ *
+ * @group param
+ */
+ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+ "The number of features to consider for splits at each tree node." +
+ s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+
+ setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+
+ /** @group setParam */
+ def setNumTrees(value: Int): this.type = {
+ require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.")
+ set(numTrees, value)
+ }
+
+ /** @group getParam */
+ final def getNumTrees: Int = getOrDefault(numTrees)
+
+ /** @group setParam */
+ def setFeatureSubsetStrategy(value: String): this.type = {
+ val strategyStr = value.toLowerCase
+ require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr),
+ s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" +
+ s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+ set(featureSubsetStrategy, strategyStr)
+ }
+
+ /** @group getParam */
+ final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
+}
+
+private[ml] object RandomForestParams {
+ // These options should be lowercase.
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
+
+ /**
+ * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+ * estimator.
+ * (default = 0.1)
+ * @group param
+ */
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
+ " learning rate) in interval (0, 1] for shrinking the contribution of each estimator")
+
+ /* TODO: Add this doc when we add this param. SPARK-7132
+ * Threshold for stopping early when runWithValidation is used.
+ * If the error rate on the validation input changes by less than the validationTol,
+ * then learning will stop early (before [[numIterations]]).
+ * This parameter is ignored when run is used.
+ * (default = 1e-5)
+ * @group param
+ */
+ // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
+ // validationTol -> 1e-5
+
+ setDefault(maxIter -> 20, stepSize -> 0.1)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = {
+ require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
+ set(maxIter, value)
+ }
+
+ /** @group setParam */
+ def setStepSize(value: Double): this.type = {
+ require(value > 0.0 && value <= 1.0,
+ s"GBT given invalid step size ($value). Value should be in (0,1].")
+ set(stepSize, value)
+ }
+
+ /** @group getParam */
+ final def getStepSize: Double = getOrDefault(stepSize)
+
+ /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
+ private[ml] def getOldBoostingStrategy(
+ categoricalFeatures: Map[Int, Int],
+ oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+ // NOTE: The old API does not support "seed" so we ignore it.
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+ }
+
+ /** Get old Gradient Boosting Loss type */
+ private[ml] def getOldLossType: OldLoss
}
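
The setters above validate eagerly and now return this.type directly. A minimal sketch of that fail-fast behavior, using the RandomForestClassifier added earlier in this patch:

{{{
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
rf.setNumTrees(50)                    // ok: must be >= 1
rf.setSubsamplingRate(0.8)            // ok: must be in (0, 1]
rf.setFeatureSubsetStrategy("SQRT")   // stored lowercased as "sqrt"
// rf.setNumTrees(0)                  // would throw IllegalArgumentException
// rf.setSubsamplingRate(1.5)         // would throw IllegalArgumentException
}}}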
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 95d7e64790c79..e88c48741e99f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name"),
ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
- ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")))
+ ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+ ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
@@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen {
|
|import org.apache.spark.annotation.DeveloperApi
|import org.apache.spark.ml.param._
+ |import org.apache.spark.util.Utils
|
|// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
|
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 72b08bf276483..a860b8834cff9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param._
+import org.apache.spark.util.Utils
// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen.
@@ -256,4 +257,23 @@ trait HasFitIntercept extends Params {
/** @group getParam */
final def getFitIntercept: Boolean = getOrDefault(fitIntercept)
}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param seed (default: Utils.random.nextLong()).
+ */
+@DeveloperApi
+trait HasSeed extends Params {
+
+ /**
+ * Param for random seed.
+ * @group param
+ */
+ final val seed: LongParam = new LongParam(this, "seed", "random seed")
+
+ setDefault(seed, Utils.random.nextLong())
+
+ /** @group getParam */
+ final def getSeed: Long = getOrDefault(seed)
+}
// scalastyle:on
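
The generated HasSeed trait is consumed like the other shared params. A minimal, hypothetical sketch of another params trait mixing it in (the name MySamplingParams is illustrative only), mirroring how TreeEnsembleParams uses it:

{{{
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.HasSeed

// Hypothetical params trait reusing the shared seed param.
trait MySamplingParams extends Params with HasSeed {
  /** @group setParam */
  def setSeed(value: Long): this.type = set(seed, value)
}
}}}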
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 49a8b77acf960..756725a64b0f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame
@AlphaComponent
final class DecisionTreeRegressor
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
- with DecisionTreeParams
- with TreeRegressorParams {
+ with DecisionTreeParams with TreeRegressorParams {
// Override parameter setters from parent trait for Java API compatibility.
@@ -60,8 +59,7 @@ final class DecisionTreeRegressor
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
- override def setCheckpointInterval(value: Int): this.type =
- super.setCheckpointInterval(value)
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
override def setImpurity(value: String): this.type = super.setImpurity(value)
@@ -78,15 +76,13 @@ final class DecisionTreeRegressor
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
- val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
- strategy.algo = OldAlgo.Regression
- strategy.setImpurity(getOldImpurity)
- strategy
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity,
+ subsamplingRate = 1.0)
}
}
object DecisionTreeRegressor {
- /** Accessor for supported impurities */
+ /** Accessor for supported impurities: variance */
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
new file mode 100644
index 0000000000000..c784cf39ed31a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap, Param}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
+ SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class GBTRegressor
+ extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
+ with GBTParams with TreeRegressorParams with Logging {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+  // Parameters from DecisionTreeParams and TreeRegressorParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ /**
+ * The impurity setting is ignored for GBT models.
+ * Individual trees are built using impurity "Variance."
+ */
+ override def setImpurity(value: String): this.type = {
+ logWarning("GBTRegressor.setImpurity should NOT be used")
+ this
+ }
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = {
+ logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
+ super.setSeed(value)
+ }
+
+ // Parameters from GBTParams:
+
+ override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
+
+ override def setStepSize(value: Double): this.type = super.setStepSize(value)
+
+ // Parameters for GBTRegressor:
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+
+ setDefault(lossType -> "squared")
+
+ /** @group setParam */
+ def setLossType(value: String): this.type = {
+ val lossStr = value.toLowerCase
+ require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
+ s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+    set(lossType, lossStr)
+ }
+
+ /** @group getParam */
+ def getLossType: String = getOrDefault(lossType)
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): GBTRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+ val oldGBT = new OldGBT(boostingStrategy)
+ val oldModel = oldGBT.run(oldDataset)
+ GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object GBTRegressor {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
+ * model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees Decision trees in the ensemble.
+ * @param _treeWeights Weights for the decision trees in the ensemble.
+ */
+@AlphaComponent
+final class GBTRegressionModel(
+ override val parent: GBTRegressor,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel],
+ private val _treeWeights: Array[Double])
+ extends PredictionModel[Vector, GBTRegressionModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+ s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+    // Returns the weighted sum of the individual tree predictions
+    // (this is a regression model, so no thresholding).
+    val treePredictions = _trees.map(_.rootNode.predict(features))
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ }
+
+ override protected def copy(): GBTRegressionModel = {
+ val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"GBTRegressionModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldGBTModel = {
+ new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+ }
+}
+
+private[ml] object GBTRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldGBTModel,
+ parent: GBTRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
+ s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+ }
+}
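
A minimal usage sketch for the new GBT regressor, again assuming a DataFrame `training` (and `test`) with "label" and "features" columns:

{{{
import org.apache.spark.ml.regression.GBTRegressor

val gbt = new GBTRegressor()
  .setMaxIter(50)          // number of boosting iterations, i.e. trees
  .setStepSize(0.1)        // shrinkage, must lie in (0, 1]
  .setLossType("squared")  // or "absolute"
val model = gbt.fit(training)           // GBTRegressionModel
val predicted = model.transform(test)   // appends a "prediction" column
}}}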
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000000000..2171ef3d32c26
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class RandomForestRegressor
+ extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
+ with RandomForestParams with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+  // Parameters from DecisionTreeParams and TreeRegressorParams:
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ // Parameters from TreeEnsembleParams:
+
+ override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
+
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
+ // Parameters from RandomForestParams:
+
+ override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
+
+ override def setFeatureSubsetStrategy(value: String): this.type =
+ super.setFeatureSubsetStrategy(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): RandomForestRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy =
+ super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
+ val oldModel = OldRandomForest.trainRegressor(
+ oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
+ RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+}
+
+object RandomForestRegressor {
+ /** Accessor for supported impurity settings: variance */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+
+ /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+ final val supportedFeatureSubsetStrategies: Array[String] =
+ RandomForestParams.supportedFeatureSubsetStrategies
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param _trees Decision trees in the ensemble.
+ */
+@AlphaComponent
+final class RandomForestRegressionModel private[ml] (
+ override val parent: RandomForestRegressor,
+ override val fittingParamMap: ParamMap,
+ private val _trees: Array[DecisionTreeRegressionModel])
+ extends PredictionModel[Vector, RandomForestRegressionModel]
+ with TreeEnsembleModel with Serializable {
+
+ require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+
+ override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+ // Note: We may add support for weights (based on tree performance) later on.
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+
+ override def treeWeights: Array[Double] = _treeWeights
+
+ override protected def predict(features: Vector): Double = {
+ // TODO: Override transform() to broadcast model. SPARK-7127
+ // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
+ // Predict average of tree predictions.
+ // Ignore the weights since all are 1.0 for now.
+ _trees.map(_.rootNode.predict(features)).sum / numTrees
+ }
+
+ override protected def copy(): RandomForestRegressionModel = {
+ val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"RandomForestRegressionModel with $numTrees trees"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldRandomForestModel = {
+ new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
+ }
+}
+
+private[ml] object RandomForestRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldRandomForestModel,
+ parent: RandomForestRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
+ s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
+ val newTrees = oldModel.trees.map { tree =>
+ // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+ DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+ }
+ new RandomForestRegressionModel(parent, fittingParamMap, newTrees)
+ }
+}
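
To make the prediction rule concrete: RandomForestRegressionModel ignores the all-1.0 treeWeights and simply averages the per-tree predictions. A tiny worked sketch with hypothetical per-tree outputs:

{{{
// Hypothetical predictions from a 3-tree regression forest for one feature vector.
val treePreds = Array(1.2, 0.8, 1.0)
val prediction = treePreds.sum / treePreds.length   // = 1.0
}}}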
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index d6e2203d9f937..d2dec0c76cb12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation
sealed abstract class Node extends Serializable {
// TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
- // code into the new API and deprecate the old API.
+ // code into the new API and deprecate the old API. SPARK-3727
- /** Prediction this node makes (or would make, if it is an internal node) */
+ /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */
def prediction: Double
/** Impurity measure at this node (for training data) */
@@ -194,7 +194,7 @@ private object InternalNode {
s"$featureStr > ${contSplit.threshold}"
}
case catSplit: CategoricalSplit =>
- val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}")
if (left) {
s"$featureStr in $categoriesStr"
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index cb940f62990ed..90f1d052764d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -38,13 +38,13 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}
-private[ml] object Split {
+private[tree] object Split {
def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
case OldFeatureType.Categorical =>
new CategoricalSplit(featureIndex = oldSplit.feature,
- leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
case OldFeatureType.Continuous =>
new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
}
@@ -54,30 +54,30 @@ private[ml] object Split {
/**
* Split which tests a categorical feature.
* @param featureIndex Index of the feature to test
- * @param leftCategories If the feature value is in this set of categories, then the split goes
- * left. Otherwise, it goes right.
+ * @param _leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
-final class CategoricalSplit(
+final class CategoricalSplit private[ml] (
override val featureIndex: Int,
- leftCategories: Array[Double],
+ _leftCategories: Array[Double],
private val numCategories: Int)
extends Split {
- require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
- s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+ require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}")
/**
* If true, then "categories" is the set of categories for splitting to the left, and vice versa.
*/
- private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+ private val isLeft: Boolean = _leftCategories.length <= numCategories / 2
/** Set of categories determining the splitting rule, along with [[isLeft]]. */
private val categories: Set[Double] = {
if (isLeft) {
- leftCategories.toSet
+ _leftCategories.toSet
} else {
- setComplement(leftCategories.toSet)
+ setComplement(_leftCategories.toSet)
}
}
@@ -107,13 +107,13 @@ final class CategoricalSplit(
}
/** Get sorted categories which split to the left */
- def getLeftCategories: Array[Double] = {
+ def leftCategories: Array[Double] = {
val cats = if (isLeft) categories else setComplement(categories)
cats.toArray.sorted
}
/** Get sorted categories which split to the right */
- def getRightCategories: Array[Double] = {
+ def rightCategories: Array[Double] = {
val cats = if (isLeft) setComplement(categories) else categories
cats.toArray.sorted
}
@@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
-final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
+ extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 8e3bc3849dcf0..1929f9d02156e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,18 +17,13 @@
package org.apache.spark.ml.tree
-import org.apache.spark.annotation.AlphaComponent
-
/**
- * :: AlphaComponent ::
- *
* Abstraction for Decision Tree models.
*
- * TODO: Add support for predicting probabilities and raw predictions
+ * TODO: Add support for predicting probabilities and raw predictions SPARK-3727
*/
-@AlphaComponent
-trait DecisionTreeModel {
+private[ml] trait DecisionTreeModel {
/** Root of the decision tree */
def rootNode: Node
@@ -58,3 +53,40 @@ trait DecisionTreeModel {
header + rootNode.subtreeToString(2)
}
}
+
+/**
+ * Abstraction for models which are ensembles of decision trees
+ *
+ * TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ */
+private[ml] trait TreeEnsembleModel {
+
+ // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
+ // DecisionTreeModel.
+
+ /** Trees in this ensemble. Warning: These have null parent Estimators. */
+ def trees: Array[DecisionTreeModel]
+
+ /** Weights for each tree, zippable with [[trees]] */
+ def treeWeights: Array[Double]
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"TreeEnsembleModel with $numTrees trees"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) =>
+ s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
+ /** Total number of nodes, summed over all trees in the ensemble. */
+ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 9d63a08e211bc..d006b39acb213 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -177,7 +177,7 @@ class LDA private (
def getBeta: Double = getTopicConcentration
/** Alias for [[setTopicConcentration()]] */
- def setBeta(beta: Double): this.type = setBeta(beta)
+ def setBeta(beta: Double): this.type = setTopicConcentration(beta)
/**
* Maximum number of iterations for learning.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 4ef171f4f0419..166c00cff634d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -526,7 +526,7 @@ class SparseVector(
s" ${values.size} values.")
override def toString: String =
- "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]"))
+ s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
override def toArray: Array[Double] = {
val data = new Array[Double](size)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 2067b36f246b3..d5fea822ad77b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -32,7 +32,7 @@ import org.apache.spark.SparkException
@BeanInfo
case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
- "(%s,%s)".format(label, features)
+ s"($label,$features)"
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index 8838ca8c14718..309f9af466457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -171,7 +171,7 @@ object RidgeRegressionWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double): RidgeRegressionModel = {
- train(input, numIterations, stepSize, regParam, 0.01)
+ train(input, numIterations, stepSize, regParam, 1.0)
}
/**
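
The change above only affects the overload that omits miniBatchFraction: it now forwards 1.0 (use the full dataset on every SGD iteration) instead of 0.01. Callers who relied on the old 1% subsampling can keep it by passing the fraction explicitly, as in the sketch below (the wrapper methods and parameter values are placeholders, not part of the patch):

import org.apache.spark.mllib.regression.{LabeledPoint, RidgeRegressionModel, RidgeRegressionWithSGD}
import org.apache.spark.rdd.RDD

// `input` is an existing RDD[LabeledPoint]; the numbers are illustrative defaults.
def fitRidge(input: RDD[LabeledPoint]): RidgeRegressionModel =
  RidgeRegressionWithSGD.train(input, 100, 1.0, 0.01)        // now full-batch (fraction 1.0)

def fitRidgeSubsampled(input: RDD[LabeledPoint]): RidgeRegressionModel =
  RidgeRegressionWithSGD.train(input, 100, 1.0, 0.01, 0.01)  // old behavior, 1% mini-batches
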
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f209fdafd3653..2d087c967f679 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -39,8 +39,8 @@ class InformationGainStats(
val rightPredict: Predict) extends Serializable {
override def toString: String = {
- "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
- .format(gain, impurity, leftImpurity, rightImpurity)
+ s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
+ s"right impurity = $rightImpurity"
}
override def equals(o: Any): Boolean = o match {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 86390a20cb5cc..431a839817eac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -51,8 +51,8 @@ class Node (
var stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString: String = {
- "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "impurity = " + impurity + ", split = " + split + ", stats = " + stats
+ s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
+ s"split = $split, stats = $stats"
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 25990af7c6cf7..5cbe7c280dbee 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -29,9 +29,7 @@ class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
- override def toString: String = {
- "predict = %f, prob = %f".format(predict, prob)
- }
+ override def toString: String = s"$predict (prob = $prob)"
override def equals(other: Any): Boolean = {
other match {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index fb35e70a8d077..be6c9b3de5479 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -39,8 +39,8 @@ case class Split(
categories: List[Double]) {
override def toString: String = {
- "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
- ", categories = " + categories
+ s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
+ s"categories = $categories"
}
}
@@ -68,4 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
*/
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
-
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 43b8787f9dd7e..60f25e5cce437 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.classification;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +55,7 @@ public void runDT() {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
@@ -71,8 +69,8 @@ public void runDT() {
.setCacheNodeIds(false)
.setCheckpointInterval(10)
.setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ for (String impurity: DecisionTreeClassifier.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeClassificationModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public void runDT() {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
new file mode 100644
index 0000000000000..3c69467fa119e
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ GBTClassifier gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTClassifier.supportedLossTypes()) {
+ gbt.setLossType(lossType);
+ }
+ GBTClassificationModel model = gbt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
new file mode 100644
index 0000000000000..32d0b3856b7e2
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestClassifier rf = new RandomForestClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestClassifier.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestClassificationModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ RandomForestClassificationModel sameModel =
+ RandomForestClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index a3a339004f31c..71b041818d7ee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.ml.regression;
-import java.io.File;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
@@ -32,7 +31,6 @@
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.util.Utils;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,22 +55,22 @@ public void runDT() {
double B = -1.5;
JavaRDD<LabeledPoint> data = sc.parallelize(
- LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
- .setMaxDepth(2)
- .setMaxBins(10)
- .setMinInstancesPerNode(5)
- .setMinInfoGain(0.0)
- .setMaxMemoryInMB(256)
- .setCacheNodeIds(false)
- .setCheckpointInterval(10)
- .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
- for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
- dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: DecisionTreeRegressor.supportedImpurities()) {
+ dt.setImpurity(impurity);
}
DecisionTreeRegressionModel model = dt.fit(dataFrame);
@@ -82,7 +80,7 @@ public void runDT() {
model.toDebugString();
/*
- // TODO: Add test once save/load are implemented.
+ // TODO: Add test once save/load are implemented. SPARK-6725
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
String path = tempDir.toURI().toString();
try {
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
new file mode 100644
index 0000000000000..fc8c13db07e6f
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaGBTRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaGBTRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ GBTRegressor gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setMaxIter(3)
+ .setStepSize(0.1)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String lossType: GBTRegressor.supportedLossTypes()) {
+ gbt.setLossType(lossType);
+ }
+ GBTRegressionModel model = gbt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
new file mode 100644
index 0000000000000..e306ebadfe7cf
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+
+
+public class JavaRandomForestRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+
+ // This tests setters. Training with various options is tested in Scala.
+ RandomForestRegressor rf = new RandomForestRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setSubsamplingRate(1.0)
+ .setSeed(1234)
+ .setNumTrees(3)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (String impurity: RandomForestRegressor.supportedImpurities()) {
+ rf.setImpurity(impurity);
+ }
+ for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
+ rf.setFeatureSubsetStrategy(featureSubsetStrategy);
+ }
+ RandomForestRegressionModel model = rf.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.totalNumNodes();
+ model.toDebugString();
+ model.trees();
+ model.treeWeights();
+
+ /*
+ // TODO: Add test once save/load are implemented. SPARK-6725
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index af88595df5245..9b31adecdcb1c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
/*
test("model save/load") {
val tempDir = Utils.createTempDir()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
new file mode 100644
index 0000000000000..e6ccc2c93cba8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[GBTClassifier]].
+ */
+class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import GBTClassifierSuite.compareAPIs
+
+ // Combinations of (maxIter, learningRate, subsamplingRate) to test
+ private val testCombinations =
+ Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+ private var data: RDD[LabeledPoint] = _
+ private var trainData: RDD[LabeledPoint] = _
+ private var validationData: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+ trainData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+ validationData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+ }
+
+ test("Binary classification with continuous features: Log Loss") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ testCombinations.foreach {
+ case (maxIter, learningRate, subsamplingRate) =>
+ val gbt = new GBTClassifier()
+ .setMaxDepth(2)
+ .setSubsamplingRate(subsamplingRate)
+ .setLossType("logistic")
+ .setMaxIter(maxIter)
+ .setStepSize(learningRate)
+ compareAPIs(data, None, gbt, categoricalFeatures)
+ }
+ }
+
+ // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
+ /*
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ // Set maxIter large enough so that it stops early.
+ val maxIter = 20
+ GBTClassifier.supportedLossTypes.foreach { loss =>
+ val gbt = new GBTClassifier()
+ .setMaxIter(maxIter)
+ .setMaxDepth(2)
+ .setLossType(loss)
+ .setValidationTol(0.0)
+ compareAPIs(trainData, None, gbt, categoricalFeatures)
+ compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+ }
+ }
+ */
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+ val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
+ val newModel = GBTClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = GBTClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object GBTClassifierSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ validationData: Option[RDD[LabeledPoint]],
+ gbt: GBTClassifier,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldBoostingStrategy =
+ gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
+ val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldModel = oldGBT.run(data)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val newModel = gbt.fit(newData)
+ // Use parent and fittingParamMap from newTree since these are not checked anyway.
+ val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
new file mode 100644
index 0000000000000..ed41a9664f94f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[RandomForestClassifier]].
+ */
+class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import RandomForestClassifierSuite.compareAPIs
+
+ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+ private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ orderedLabeledPoints50_1000 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+ orderedLabeledPoints5_20 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) {
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+ val newRF = rf
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestClassifier()
+ binaryClassificationTestWithContinuousFeatures(rf)
+ }
+
+ test("Binary classification with continuous features and node Id cache:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestClassifier()
+ .setCacheNodeIds(true)
+ binaryClassificationTestWithContinuousFeatures(rf)
+ }
+
+ test("alternating categorical and continuous features with multiclass labels to test indexing") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)),
+ LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
+ )
+ val rdd = sc.parallelize(arr)
+ val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4)
+ val numClasses = 3
+
+ val rf = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(5)
+ .setNumTrees(2)
+ .setFeatureSubsetStrategy("sqrt")
+ .setSeed(12345)
+ compareAPIs(rdd, rf, categoricalFeatures, numClasses)
+ }
+
+ test("subsampling rate in RandomForest"){
+ val rdd = orderedLabeledPoints5_20
+ val categoricalFeatures = Map.empty[Int, Int]
+ val numClasses = 2
+
+ val rf1 = new RandomForestClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setCacheNodeIds(true)
+ .setNumTrees(3)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(rdd, rf1, categoricalFeatures, numClasses)
+
+ val rf2 = rf1.setSubsamplingRate(0.5)
+ compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees =
+ Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
+ val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
+ val newModel = RandomForestClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = RandomForestClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object RandomForestClassifierSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ rf: RandomForestClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy =
+ rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
+ val oldModel = OldRandomForest.trainClassifier(
+ data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newModel = rf.fit(newData)
+ // Use parent and fittingParamMap from newTree since these are not checked anyway.
+ val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
new file mode 100644
index 0000000000000..eaee3443c1f23
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+class IDFSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
+ dataSet.map {
+ case data: DenseVector =>
+ val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
+ Vectors.dense(res)
+ case data: SparseVector =>
+ val res = data.indices.zip(data.values).map { case (id, value) =>
+ (id, value * model(id))
+ }
+ Vectors.sparse(data.size, res)
+ }
+ }
+
+ test("compute IDF with default parameter") {
+ val numOfFeatures = 4
+ val data = Array(
+ Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+ Vectors.dense(0.0, 1.0, 2.0, 3.0),
+ Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
+ )
+ val numOfData = data.size
+ val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+ math.log((numOfData + 1.0) / (x + 1.0))
+ })
+ val expected = scaleDataWithIDF(data, idf)
+
+ val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+ val idfModel = new IDF()
+ .setInputCol("features")
+ .setOutputCol("idfValue")
+ .fit(df)
+
+ idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x ~== y absTol 1e-5, "Transformed vector differs from the expected vector.")
+ }
+ }
+
+ test("compute IDF with setter") {
+ val numOfFeatures = 4
+ val data = Array(
+ Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
+ Vectors.dense(0.0, 1.0, 2.0, 3.0),
+ Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
+ )
+ val numOfData = data.size
+ val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+ if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
+ })
+ val expected = scaleDataWithIDF(data, idf)
+
+ val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
+
+ val idfModel = new IDF()
+ .setInputCol("features")
+ .setOutputCol("idfValue")
+ .setMinDocFreq(1)
+ .fit(df)
+
+ idfModel.transform(df).select("idfValue", "expected").collect().foreach {
+ case Row(x: Vector, y: Vector) =>
+ assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
+ }
+ }
+}
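
For reference, the expected values in the suite above come from recomputing the IDF weights by hand: with m documents and per-feature document frequency df, the test uses idf = log((m + 1) / (df + 1)), zeroed out below the minimum document frequency in the second test. Worked out for the three documents above:

// m = 3 documents; document frequencies per feature (count of docs with a nonzero entry)
// are [0, 3, 1, 2], which is the Array(0, 3, 1, 2) used in the tests.
val m = 3.0
val docFreq = Array(0.0, 3.0, 1.0, 2.0)
val idf = docFreq.map(df => math.log((m + 1.0) / (df + 1.0)))
// idf is approximately [1.386, 0.0, 0.693, 0.288]; scaleDataWithIDF then multiplies every
// input vector element-wise by this vector to build the expected output column.
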
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
new file mode 100644
index 0000000000000..c1d64fba0aa8f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+import org.scalatest.exceptions.TestFailedException
+
+class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("Polynomial expansion with default parameter") {
+ val data = Array(
+ Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+ Vectors.dense(-2.0, 2.3),
+ Vectors.dense(0.0, 0.0, 0.0),
+ Vectors.dense(0.6, -1.1, -3.0),
+ Vectors.sparse(3, Seq())
+ )
+
+ val twoDegreeExpansion: Array[Vector] = Array(
+ Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)),
+ Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29),
+ Vectors.dense(new Array[Double](9)),
+ Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0),
+ Vectors.sparse(9, Array.empty, Array.empty))
+
+ val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected")
+
+ val polynomialExpansion = new PolynomialExpansion()
+ .setInputCol("features")
+ .setOutputCol("polyFeatures")
+
+ polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
+ case Row(expanded: DenseVector, expected: DenseVector) =>
+ assert(expanded ~== expected absTol 1e-1)
+ case Row(expanded: SparseVector, expected: SparseVector) =>
+ assert(expanded ~== expected absTol 1e-1)
+ case _ =>
+ throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+ }
+ }
+
+ test("Polynomial expansion with setter") {
+ val data = Array(
+ Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+ Vectors.dense(-2.0, 2.3),
+ Vectors.dense(0.0, 0.0, 0.0),
+ Vectors.dense(0.6, -1.1, -3.0),
+ Vectors.sparse(3, Seq())
+ )
+
+ val threeDegreeExpansion: Array[Vector] = Array(
+ Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8),
+ Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)),
+ Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17),
+ Vectors.dense(new Array[Double](19)),
+ Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8,
+ -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0),
+ Vectors.sparse(19, Array.empty, Array.empty))
+
+ val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected")
+
+ val polynomialExpansion = new PolynomialExpansion()
+ .setInputCol("features")
+ .setOutputCol("polyFeatures")
+ .setDegree(3)
+
+ polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach {
+ case Row(expanded: DenseVector, expected: DenseVector) =>
+ assert(expanded ~== expected absTol 1e-1)
+ case Row(expanded: SparseVector, expected: SparseVector) =>
+ assert(expanded ~== expected absTol 1e-1)
+ case _ =>
+ throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
+ }
+ }
+}
+
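
As a sanity check of the expected vectors above: for a two-feature input (x, y), the degree-2 expansion the test expects is ordered (x, x^2, y, x*y, y^2). Working it out for the first dense row:

// First dense input above: (x, y) = (-2.0, 2.3).
val (x, y) = (-2.0, 2.3)
val expanded = Array(x, x * x, y, x * y, y * y)
// expanded is approximately (-2.0, 4.0, 2.3, -4.6, 5.29), matching
// Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29) within the suite's absTol of 1e-1.
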
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
index 2e57d4ce37f1d..1505ad872536b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -23,8 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
-import org.apache.spark.ml.impl.tree._
-import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.ml.tree._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
@@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite {
}
}
- // TODO: Reinstate after adding ensembles
/**
* Check if the two models are exactly the same.
* If the models are not equal, this throws an exception.
*/
- /*
def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
try {
- a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
TreeTests.checkEqual(treeA, treeB)
}
- assert(a.getTreeWeights === b.getTreeWeights)
+ assert(a.treeWeights === b.treeWeights)
} catch {
case ex: Exception => throw new AssertionError(
"checkEqual failed since the two tree ensembles were not identical")
}
}
- */
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 0b40fe33fae9d..c87a171b4b229 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: test("model save/load")
+ // TODO: test("model save/load") SPARK-6725
}
private[ml] object DecisionTreeRegressorSuite extends FunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
new file mode 100644
index 0000000000000..4aec36948ac92
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[GBTRegressor]].
+ */
+class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import GBTRegressorSuite.compareAPIs
+
+ // Combinations for estimators, learning rates and subsamplingRate
+ private val testCombinations =
+ Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+ private var data: RDD[LabeledPoint] = _
+ private var trainData: RDD[LabeledPoint] = _
+ private var validationData: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2)
+ trainData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2)
+ validationData =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
+ }
+
+ test("Regression with continuous features: SquaredError") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ GBTRegressor.supportedLossTypes.foreach { loss =>
+ testCombinations.foreach {
+ case (maxIter, learningRate, subsamplingRate) =>
+ val gbt = new GBTRegressor()
+ .setMaxDepth(2)
+ .setSubsamplingRate(subsamplingRate)
+ .setLossType(loss)
+ .setMaxIter(maxIter)
+ .setStepSize(learningRate)
+ compareAPIs(data, None, gbt, categoricalFeatures)
+ }
+ }
+ }
+
+ // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
+ /*
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val categoricalFeatures = Map.empty[Int, Int]
+ // Set maxIter large enough so that it stops early.
+ val maxIter = 20
+ GBTRegressor.supportedLossTypes.foreach { loss =>
+ val gbt = new GBTRegressor()
+ .setMaxIter(maxIter)
+ .setMaxDepth(2)
+ .setLossType(loss)
+ .setValidationTol(0.0)
+ compareAPIs(trainData, None, gbt, categoricalFeatures)
+ compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures)
+ }
+ }
+ */
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+ val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
+ val newModel = GBTRegressionModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = GBTRegressionModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object GBTRegressorSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ validationData: Option[RDD[LabeledPoint]],
+ gbt: GBTRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
+ val oldGBT = new OldGBT(oldBoostingStrategy)
+ val oldModel = oldGBT.run(data)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newModel = gbt.fit(newData)
+ // Use parent and fittingParamMap from newTree since these are not checked anyway.
+ val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
new file mode 100644
index 0000000000000..c6dc1cc29b6ff
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * Test suite for [[RandomForestRegressor]].
+ */
+class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import RandomForestRegressorSuite.compareAPIs
+
+ private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ orderedLabeledPoints50_1000 =
+ sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val newRF = rf
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setNumTrees(1)
+ .setFeatureSubsetStrategy("auto")
+ .setSeed(123)
+ compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo)
+ }
+
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestRegressor()
+ regressionTestWithContinuousFeatures(rf)
+ }
+
+ test("Regression with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val rf = new RandomForestRegressor()
+ .setCacheNodeIds(true)
+ regressionTestWithContinuousFeatures(rf)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented SPARK-6725
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
+ val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
+ val newModel = RandomForestRegressionModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = RandomForestRegressionModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private object RandomForestRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 models on the given dataset, one using the old API and one using the new API.
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ rf: RandomForestRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy =
+ rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
+ val oldModel = OldRandomForest.trainRegressor(
+ data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newModel = rf.fit(newData)
+ // Use parent and fittingParamMap from newTree since these are not checked anyway.
+ val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent,
+ newModel.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldModelAsNew, newModel)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 15de10fd13a19..cc747dabb9968 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -123,6 +123,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds)
assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0))))
}
+
+ test("setter alias") {
+ val lda = new LDA().setAlpha(2.0).setBeta(3.0)
+ assert(lda.getAlpha === 2.0)
+ assert(lda.getDocConcentration === 2.0)
+ assert(lda.getBeta === 3.0)
+ assert(lda.getTopicConcentration === 3.0)
+ }
}
private[clustering] object LDASuite {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 249b8eae19b17..ce983eb27fa35 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite {
node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
categories = List(0.0, 1.0)))
}
- // TODO: The information gain stats should be consistent with the same info stored in children.
+ // TODO: The information gain stats should be consistent with info in children: SPARK-7131
node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
node
@@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
- * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
+ * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131
*/
- private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
+ private[spark] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
diff --git a/pom.xml b/pom.xml
index bcc2f57f1af5d..9fbce1d639d8b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -146,7 +146,7 @@
0.7.11.8.31.1.0
- 4.2.6
+ 4.3.23.4.1${project.build.directory}/spark-test-classpath.txt2.10.4
@@ -420,6 +420,16 @@
jsr3051.3.9
+
+ org.apache.httpcomponents
+ httpclient
+ ${commons.httpclient.version}
+
+
+ org.apache.httpcomponents
+ httpcore
+ ${commons.httpclient.version}
+ org.seleniumhq.seleniumselenium-java
@@ -1735,9 +1745,9 @@
scala-2.11
- 2.11.2
+ 2.11.62.11
- 2.12
+ 2.12.1jline
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index 628ccc01cf3cc..d8df02bdbaba9 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -15,6 +15,10 @@
# limitations under the License.
#
+import numpy
+from numpy import array
+from collections import namedtuple
+
from pyspark import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
@@ -36,14 +40,14 @@ class FPGrowthModel(JavaModelWrapper):
>>> rdd = sc.parallelize(data, 2)
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> sorted(model.freqItemsets().collect())
- [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
+ [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
"""
def freqItemsets(self):
"""
- Get the frequent itemsets of this model
+ Returns the frequent itemsets of this model.
"""
- return self.call("getFreqItemsets")
+ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
class FPGrowth(object):
@@ -67,6 +71,11 @@ def train(cls, data, minSupport=0.3, numPartitions=-1):
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
return FPGrowthModel(model)
+ class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
+ """
+ Represents an (items, freq) tuple.
+ """
+
def _test():
import doctest
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ca9bf8efb945c..4759f5fe783ad 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -452,6 +452,20 @@ def columns(self):
"""
return [f.name for f in self.schema.fields]
+ @ignore_unicode_prefix
+ def alias(self, alias):
+ """Returns a new :class:`DataFrame` with an alias set.
+
+ >>> from pyspark.sql.functions import *
+ >>> df_as1 = df.alias("df_as1")
+ >>> df_as2 = df.alias("df_as2")
+ >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
+ >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
+ [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
+ """
+ assert isinstance(alias, basestring), "alias should be a string"
+ return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
+
@ignore_unicode_prefix
def join(self, other, joinExprs=None, joinType=None):
"""Joins with another :class:`DataFrame`, using the given join expression.
@@ -459,16 +473,23 @@ def join(self, other, joinExprs=None, joinType=None):
The following performs a full outer join between ``df1`` and ``df2``.
:param other: Right side of the join
- :param joinExprs: Join expression
+ :param joinExprs: a string for the join column name, or a join expression (Column).
+ If joinExprs is a string indicating the name of the join column,
+ the column must exist on both sides, and this performs an inner equi-join.
:param joinType: str, default 'inner'.
One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
[Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+
+ >>> df.join(df2, 'name').select(df.name, df2.height).collect()
+ [Row(name=u'Bob', height=85)]
"""
if joinExprs is None:
jdf = self._jdf.join(other._jdf)
+ elif isinstance(joinExprs, basestring):
+ jdf = self._jdf.join(other._jdf, joinExprs)
else:
assert isinstance(joinExprs, Column), "joinExprs should be Column"
if joinType is None:
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bb47923f24b82..f48b7b5d10af7 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -75,6 +75,20 @@ def _(col):
__all__.sort()
+def approxCountDistinct(col, rsd=None):
+ """Returns a new :class:`Column` for approximate distinct count of ``col``.
+
+ >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+
def countDistinct(col, *cols):
"""Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
@@ -89,18 +103,16 @@ def countDistinct(col, *cols):
return Column(jc)
-def approxCountDistinct(col, rsd=None):
- """Returns a new :class:`Column` for approximate distinct count of ``col``.
+def sparkPartitionId():
+ """Returns a column for partition ID of the Spark task.
- >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
+ Note that this is non-deterministic because it depends on data partitioning and task scheduling.
+
+ >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect()
+ [Row(pid=0), Row(pid=0)]
"""
sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
- else:
- jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
- return Column(jc)
+ return Column(sc._jvm.functions.sparkPartitionId())
class UserDefinedFunction(object):
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 1bb62c84abddc..1cb910f376060 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -1129,7 +1129,7 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings
def apply(line: String): Result = debugging(s"""parse("$line")""") {
var isIncomplete = false
- currentRun.reporting.withIncompleteHandler((_, _) => isIncomplete = true) {
+ currentRun.parsing.withIncompleteHandler((_, _) => isIncomplete = true) {
reporter.reset()
val trees = newUnitParser(line).parseStats()
if (reporter.hasErrors) Error
diff --git a/sql/README.md b/sql/README.md
index 237620e3fa808..46aec7cef7984 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -12,7 +12,10 @@ Spark SQL is broken up into four subprojects:
Other dependencies for developers
---------------------------------
-In order to create new hive test cases , you will need to set several environmental variables.
+In order to create new Hive test cases (i.e. a test suite based on `HiveComparisonTest`),
+you will need to set up your development environment based on the following instructions.
+
+If you are working with Hive 0.12.0, you will need to set several environment variables as follows.
```
export HIVE_HOME="/hive/build/dist"
@@ -20,6 +23,24 @@ export HIVE_DEV_HOME="/hive/"
export HADOOP_HOME="/hadoop-1.0.4"
```
+If you are working with Hive 0.13.1, the following steps are needed:
+
+1. Download Hive [release 0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (see [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)).
+2. Set `HADOOP_HOME` with `export HADOOP_HOME=""`.
+3. Download all Hive 0.13.1a jars (the Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace the corresponding original 0.13.1 jars in `$HIVE_HOME/lib`.
+4. Download the [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (note: the 2.22 jar does not work) and the [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`.
+5. This step is optional, but when generating golden answer files, if a Hive query fails because Hive tries to talk to HDFS, or you see odd runtime NPEs, set the following in your test suite:
+
+```
+val testTempDir = Utils.createTempDir()
+// We have to use kryo to let Hive correctly serialize some plans.
+sql("set hive.plan.serialization.format=kryo")
+// Explicitly set fs to local fs.
+sql(s"set fs.default.name=file://$testTempDir/")
+// Ask Hive to run jobs in-process as a single map and reduce task.
+sql("set mapred.job.tracker=local")
+```
+
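For context, these settings would typically sit near the top of a comparison-test suite. A minimal, illustrative sketch (the suite name and query are hypothetical; only the `sql(...)` settings come from the steps above, and `createQueryTest` is assumed to be inherited from `HiveComparisonTest`):

```scala
// Hypothetical suite; only the sql(...) settings are taken from this document.
class MyNewHiveQuerySuite extends HiveComparisonTest {
  private val testTempDir = Utils.createTempDir()

  // Apply the workarounds from step 5 before any golden answer files are generated.
  sql("set hive.plan.serialization.format=kryo")
  sql(s"set fs.default.name=file://$testTempDir/")
  sql("set mapred.job.tracker=local")

  // Each createQueryTest call produces (or checks) a golden answer file.
  createQueryTest("my new test case", "SELECT key, value FROM src LIMIT 1")
}
```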
Using the console
=================
An interactive scala console can be invoked by running `build/sbt hive/console`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index 3823584287741..1f3c02478bd68 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -32,7 +32,7 @@ private[sql] object KeywordNormalizer {
private[sql] abstract class AbstractSparkSQLParser
extends StandardTokenParsers with PackratParsers {
- def apply(input: String): LogicalPlan = {
+ def parse(input: String): LogicalPlan = {
// Initialize the Keywords.
lexical.initialize(reservedWords)
phrase(start)(new lexical.Scanner(input)) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d9521953cad73..c52965507c715 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst
-import java.sql.Timestamp
-
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -110,7 +108,7 @@ trait ScalaReflection {
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
- case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+ case t if t <:< typeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true)
case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
@@ -136,20 +134,20 @@ trait ScalaReflection {
def typeOfObject: PartialFunction[Any, DataType] = {
// The data type can be determined without ambiguity.
- case obj: BooleanType.JvmType => BooleanType
- case obj: BinaryType.JvmType => BinaryType
+ case obj: Boolean => BooleanType
+ case obj: Array[Byte] => BinaryType
case obj: String => StringType
- case obj: StringType.JvmType => StringType
- case obj: ByteType.JvmType => ByteType
- case obj: ShortType.JvmType => ShortType
- case obj: IntegerType.JvmType => IntegerType
- case obj: LongType.JvmType => LongType
- case obj: FloatType.JvmType => FloatType
- case obj: DoubleType.JvmType => DoubleType
+ case obj: UTF8String => StringType
+ case obj: Byte => ByteType
+ case obj: Short => ShortType
+ case obj: Int => IntegerType
+ case obj: Long => LongType
+ case obj: Float => FloatType
+ case obj: Double => DoubleType
case obj: java.sql.Date => DateType
case obj: java.math.BigDecimal => DecimalType.Unlimited
case obj: Decimal => DecimalType.Unlimited
- case obj: TimestampType.JvmType => TimestampType
+ case obj: java.sql.Timestamp => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
// Catalyst data type. A user should provide his/her specific rules
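As a quick illustration of the new runtime-value matching, here is a hedged sketch; it assumes `typeOfObject` remains publicly reachable on the `ScalaReflection` object as shown above:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

// Plain JVM values are now matched by their Scala/Java runtime classes
// rather than by the removed *Type.JvmType aliases.
assert(ScalaReflection.typeOfObject(1) == IntegerType)
assert(ScalaReflection.typeOfObject(1.0) == DoubleType)
assert(ScalaReflection.typeOfObject(new java.sql.Timestamp(0L)) == TimestampType)

// Values with no obvious Catalyst mapping are simply not covered by the partial function.
assert(!ScalaReflection.typeOfObject.isDefinedAt(Map("k" -> 1)))
```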
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 4e5c64bb63c9f..5d5aba9644ff7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -296,7 +296,7 @@ package object dsl {
InsertIntoTable(
analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false)
- def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan))
+ def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan))
}
object plans { // scalastyle:ignore
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 566b34f7c3a6a..140ccd8d3796f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
}
lazy val ordering = left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
@@ -391,7 +391,7 @@ case class MinOf(left: Expression, right: Expression) extends Expression {
}
lazy val ordering = left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index be2c101d63a63..dbc92fb93e95e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -98,11 +98,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
})
/** Generates the requested evaluator binding the given expression(s) to the inputSchema. */
- def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType =
- apply(bind(expressions, inputSchema))
+ def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType =
+ generate(bind(expressions, inputSchema))
/** Generates the requested evaluator given already bound expression(s). */
- def apply(expressions: InType): OutType = cache.get(canonicalize(expressions))
+ def generate(expressions: InType): OutType = cache.get(canonicalize(expressions))
/**
* Returns a term name that is unique within this instance of a `CodeGenerator`.
@@ -279,7 +279,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString)
""".children
- case EqualTo(e1: BinaryType, e2: BinaryType) =>
+ case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) =>
(e1, e2).evaluateAs (BooleanType) {
case (eval1, eval2) =>
q"""
@@ -623,7 +623,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = {
dataType match {
case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]"
- case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)"
+ case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)"
case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]"
}
}
@@ -635,7 +635,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
value: TermName) = {
dataType match {
case StringType => q"$destinationRow.update($ordinal, $value)"
- case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
+ case dt: DataType if isNativeType(dt) =>
+ q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
case _ => q"$destinationRow.update($ordinal, $value)"
}
}
@@ -675,7 +676,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
protected def termForType(dt: DataType) = dt match {
- case n: NativeType => n.tag
+ case n: AtomicType => n.tag
case _ => typeTag[Any]
}
+
+ /**
+ * List of data types that have special accessors and setters in [[Row]].
+ */
+ protected val nativeTypes =
+ Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+
+ /**
+ * Returns true if the data type has a special accessor and setter in [[Row]].
+ */
+ protected def isNativeType(dt: DataType) = nativeTypes.contains(dt)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index a419fd7ecb39b..840260703ab74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -30,7 +30,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val mutableRowName = newTermName("mutableRow")
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
- in.map(ExpressionCanonicalizer(_))
+ in.map(ExpressionCanonicalizer.execute)
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index fc2a2b60703e4..b129c0d898bb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -30,7 +30,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] =
- in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder])
+ in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
in.map(BindReferences.bindReference(_, inputSchema))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 2a0935c790cf3..40e163024360e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -26,7 +26,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
import scala.reflect.runtime.{universe => ru}
import scala.reflect.runtime.universe._
- protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in)
+ protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
BindReferences.bindReference(in, inputSchema)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 6f572ff959fb4..584f938445c8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -31,7 +31,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
import scala.reflect.runtime.universe._
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
- in.map(ExpressionCanonicalizer(_))
+ in.map(ExpressionCanonicalizer.execute)
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
in.map(BindReferences.bindReference(_, inputSchema))
@@ -109,7 +109,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }"
}
- val specificAccessorFunctions = NativeType.all.map { dataType =>
+ val specificAccessorFunctions = nativeTypes.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
// getString() is not used by expressions
case (e, i) if e.dataType == dataType && dataType != StringType =>
@@ -135,7 +135,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
}
}
- val specificMutatorFunctions = NativeType.all.map { dataType =>
+ val specificMutatorFunctions = nativeTypes.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
// setString() is not used by expressions
case (e, i) if e.dataType == dataType && dataType != StringType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index fcd6352079b4d..9cb00cb2732ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType}
+import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType}
object InterpretedPredicate {
- def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
- apply(BindReferences.bindReference(expression, inputSchema))
+ def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
+ create(BindReferences.bindReference(expression, inputSchema))
- def apply(expression: Expression): (Row => Boolean) = {
+ def create(expression: Expression): (Row => Boolean) = {
(r: Row) => expression.eval(r).asInstanceOf[Boolean]
}
}
@@ -211,7 +211,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
s"Types do not match ${left.dataType} != ${right.dataType}")
}
left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
}
@@ -240,7 +240,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
s"Types do not match ${left.dataType} != ${right.dataType}")
}
left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
}
@@ -269,7 +269,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
s"Types do not match ${left.dataType} != ${right.dataType}")
}
left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
}
@@ -298,7 +298,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
s"Types do not match ${left.dataType} != ${right.dataType}")
}
left.dataType match {
- case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case other => sys.error(s"Type $other does not support ordered operations")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 981373477a4bc..5fd892c42e69c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType}
/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -227,9 +227,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
return if (order.direction == Ascending) 1 else -1
} else {
val comparison = order.dataType match {
- case n: NativeType if order.direction == Ascending =>
+ case n: AtomicType if order.direction == Ascending =>
n.ordering.asInstanceOf[Ordering[Any]].compare(left, right)
- case n: NativeType if order.direction == Descending =>
+ case n: AtomicType if order.direction == Descending =>
n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right)
case other => sys.error(s"Type $other does not support ordered operations")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c8c643f7d17a..4574934d910db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -92,7 +92,7 @@ object PhysicalOperation extends PredicateHelper {
}
def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
- case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child
+ case a @ Alias(child, _) => a.toAttribute -> child
}.toMap
def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index c441f0bf24d85..3f9858b0c4a43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -45,7 +45,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
* Executes the batches of rules defined by the subclass. The batches are executed serially
* using the defined execution strategy. Within each batch, rules are also executed serially.
*/
- def apply(plan: TreeType): TreeType = {
+ def execute(plan: TreeType): TreeType = {
var curPlan = plan
batches.foreach { batch =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
new file mode 100644
index 0000000000000..b116163faccad
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.DeveloperApi
+
+
+object ArrayType {
+ /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
+ def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
+}
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type for collections of multiple values.
+ * Internally these are represented as columns that contain a ``scala.collection.Seq``.
+ *
+ * Please use [[DataTypes.createArrayType()]] to create a specific instance.
+ *
+ * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
+ * `containsNull: Boolean`. The `elementType` field specifies the data type of the array's
+ * elements. The `containsNull` field specifies whether the array can contain `null` values.
+ *
+ * @param elementType The data type of the array's elements.
+ * @param containsNull Indicates whether the array can contain `null` values.
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
+
+ /** No-arg constructor for kryo. */
+ protected def this() = this(null, false)
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(
+ s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
+ DataType.buildFormattedString(elementType, s"$prefix |", builder)
+ }
+
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("elementType" -> elementType.jsonValue) ~
+ ("containsNull" -> containsNull)
+
+ /**
+ * The default size of a value of the ArrayType is 100 * the default size of the element type.
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * elementType.defaultSize
+
+ override def simpleString: String = s"array<${elementType.simpleString}>"
+
+ private[spark] override def asNullable: ArrayType =
+ ArrayType(elementType.asNullable, containsNull = true)
+}
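A small usage sketch of the type defined above, using only the defaults and size estimates documented in this file:

```scala
import org.apache.spark.sql.types._

// One-argument factory: containsNull defaults to true.
val arr = ArrayType(IntegerType)
assert(arr.containsNull)
assert(arr.simpleString == "array<int>")

// Default size estimate: 100 elements times the element type's default size (4 bytes for int).
assert(arr.defaultSize == 100 * IntegerType.defaultSize)

// JSON representation carries elementType and containsNull.
println(arr.json) // {"type":"array","elementType":"integer","containsNull":true}
```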
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
new file mode 100644
index 0000000000000..a581a9e9468ef
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Array[Byte]` values.
+ * Please use the singleton [[DataTypes.BinaryType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class BinaryType private() extends AtomicType {
+ // The companion object and this class are separated so that the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+
+ private[sql] type InternalType = Array[Byte]
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+
+ private[sql] val ordering = new Ordering[InternalType] {
+ def compare(x: Array[Byte], y: Array[Byte]): Int = {
+ for (i <- 0 until x.length; if i < y.length) {
+ val res = x(i).compareTo(y(i))
+ if (res != 0) return res
+ }
+ x.length - y.length
+ }
+ }
+
+ /**
+ * The default size of a value of the BinaryType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
+
+ private[spark] override def asNullable: BinaryType = this
+}
+
+
+case object BinaryType extends BinaryType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
new file mode 100644
index 0000000000000..a7f228cefa57a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class BooleanType private() extends AtomicType {
+ // The companion object and this class are separated so that the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Boolean
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the BooleanType is 1 byte.
+ */
+ override def defaultSize: Int = 1
+
+ private[spark] override def asNullable: BooleanType = this
+}
+
+
+case object BooleanType extends BooleanType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
new file mode 100644
index 0000000000000..4d8685796ec76
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.{Ordering, Integral, Numeric}
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class ByteType private() extends IntegralType {
+ // The companion object and this class are separated so that the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "ByteType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Byte
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Byte]]
+ private[sql] val integral = implicitly[Integral[Byte]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the ByteType is 1 byte.
+ */
+ override def defaultSize: Int = 1
+
+ override def simpleString: String = "tinyint"
+
+ private[spark] override def asNullable: ByteType = this
+}
+
+case object ByteType extends ByteType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
new file mode 100644
index 0000000000000..0992a7c311ee2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -0,0 +1,385 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{TypeTag, runtimeMirror}
+import scala.util.parsing.combinator.RegexParsers
+
+import org.json4s._
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.util.Utils
+
+
+/**
+ * :: DeveloperApi ::
+ * The base type of all Spark SQL data types.
+ *
+ * @group dataType
+ */
+@DeveloperApi
+abstract class DataType {
+ /**
+ * Enables matching against DataType for expressions:
+ * {{{
+ * case Cast(child @ BinaryType(), StringType) =>
+ * ...
+ * }}}
+ */
+ private[sql] def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType == this => true
+ case _ => false
+ }
+
+ /**
+ * The default size of a value of this data type, used internally for size estimation.
+ */
+ def defaultSize: Int
+
+ /** Name of the type used in JSON serialization. */
+ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
+
+ private[sql] def jsonValue: JValue = typeName
+
+ /** The compact JSON representation of this data type. */
+ def json: String = compact(render(jsonValue))
+
+ /** The pretty (i.e. indented) JSON representation of this data type. */
+ def prettyJson: String = pretty(render(jsonValue))
+
+ /** Readable string representation for the type. */
+ def simpleString: String = typeName
+
+ /**
+ * Check if `this` and `other` are the same data type when ignoring nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ */
+ private[spark] def sameType(other: DataType): Boolean =
+ DataType.equalsIgnoreNullability(this, other)
+
+ /**
+ * Returns the same data type but with all nullability fields set to true
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ */
+ private[spark] def asNullable: DataType
+}
+
+
+/**
+ * An internal type used to represent everything that is not null and is not a UDT, array, struct, or map.
+ */
+protected[sql] abstract class AtomicType extends DataType {
+ private[sql] type InternalType
+ @transient private[sql] val tag: TypeTag[InternalType]
+ private[sql] val ordering: Ordering[InternalType]
+
+ @transient private[sql] val classTag = ScalaReflectionLock.synchronized {
+ val mirror = runtimeMirror(Utils.getSparkClassLoader)
+ ClassTag[InternalType](mirror.runtimeClass(tag.tpe))
+ }
+}
+
+
+/**
+ * :: DeveloperApi ::
+ * Numeric data types.
+ *
+ * @group dataType
+ */
+abstract class NumericType extends AtomicType {
+ // Unfortunately we can't get this implicitly, as that breaks Spark serialization. In order for
+ // implicitly[Numeric[InternalType]] to be valid, we would have to change InternalType from a type
+ // member to a type parameter with a Numeric context bound (i.e., [InternalType : Numeric]). The
+ // compiler desugars that into an argument to the object's constructor, so there would no longer
+ // be a no-argument constructor and the JVM could not serialize the object anymore.
+ private[sql] val numeric: Numeric[InternalType]
+}
+
+
+private[sql] object NumericType {
+ /**
+ * Enables matching against NumericType for expressions:
+ * {{{
+ * case Cast(child @ NumericType(), StringType) =>
+ * ...
+ * }}}
+ */
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+}
+
+
+private[sql] object IntegralType {
+ /**
+ * Enables matching against IntegralType for expressions:
+ * {{{
+ * case Cast(child @ IntegralType(), StringType) =>
+ * ...
+ * }}}
+ */
+ def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
+ case _ => false
+ }
+}
+
+
+private[sql] abstract class IntegralType extends NumericType {
+ private[sql] val integral: Integral[InternalType]
+}
+
+
+private[sql] object FractionalType {
+ /**
+ * Enables matching against FractionalType for expressions:
+ * {{{
+ * case Cast(child @ FractionalType(), StringType) =>
+ * ...
+ * }}}
+ */
+ def unapply(a: Expression): Boolean = a match {
+ case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
+ case _ => false
+ }
+}
+
+
+private[sql] abstract class FractionalType extends NumericType {
+ private[sql] val fractional: Fractional[InternalType]
+ private[sql] val asIntegral: Integral[InternalType]
+}
+
+
+object DataType {
+
+ def fromJson(json: String): DataType = parseDataType(parse(json))
+
+ @deprecated("Use DataType.fromJson instead", "1.2.0")
+ def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
+
+ private val nonDecimalNameToType = {
+ Seq(NullType, DateType, TimestampType, BinaryType,
+ IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+ .map(t => t.typeName -> t).toMap
+ }
+
+ /** Given the string representation of a type, return its DataType */
+ private def nameToType(name: String): DataType = {
+ val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+ name match {
+ case "decimal" => DecimalType.Unlimited
+ case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+ case other => nonDecimalNameToType(other)
+ }
+ }
+
+ private object JSortedObject {
+ def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
+ case JObject(seq) => Some(seq.toList.sortBy(_._1))
+ case _ => None
+ }
+ }
+
+ // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
+ private def parseDataType(json: JValue): DataType = json match {
+ case JString(name) =>
+ nameToType(name)
+
+ case JSortedObject(
+ ("containsNull", JBool(n)),
+ ("elementType", t: JValue),
+ ("type", JString("array"))) =>
+ ArrayType(parseDataType(t), n)
+
+ case JSortedObject(
+ ("keyType", k: JValue),
+ ("type", JString("map")),
+ ("valueContainsNull", JBool(n)),
+ ("valueType", v: JValue)) =>
+ MapType(parseDataType(k), parseDataType(v), n)
+
+ case JSortedObject(
+ ("fields", JArray(fields)),
+ ("type", JString("struct"))) =>
+ StructType(fields.map(parseStructField))
+
+ case JSortedObject(
+ ("class", JString(udtClass)),
+ ("pyClass", _),
+ ("sqlType", _),
+ ("type", JString("udt"))) =>
+ Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
+ }
+
+ private def parseStructField(json: JValue): StructField = json match {
+ case JSortedObject(
+ ("metadata", metadata: JObject),
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
+ // Support reading schema when 'metadata' is missing.
+ case JSortedObject(
+ ("name", JString(name)),
+ ("nullable", JBool(nullable)),
+ ("type", dataType: JValue)) =>
+ StructField(name, parseDataType(dataType), nullable)
+ }
+
+ private object CaseClassStringParser extends RegexParsers {
+ protected lazy val primitiveType: Parser[DataType] =
+ ( "StringType" ^^^ StringType
+ | "FloatType" ^^^ FloatType
+ | "IntegerType" ^^^ IntegerType
+ | "ByteType" ^^^ ByteType
+ | "ShortType" ^^^ ShortType
+ | "DoubleType" ^^^ DoubleType
+ | "LongType" ^^^ LongType
+ | "BinaryType" ^^^ BinaryType
+ | "BooleanType" ^^^ BooleanType
+ | "DateType" ^^^ DateType
+ | "DecimalType()" ^^^ DecimalType.Unlimited
+ | fixedDecimalType
+ | "TimestampType" ^^^ TimestampType
+ )
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
+
+ protected lazy val arrayType: Parser[DataType] =
+ "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
+ }
+
+ protected lazy val mapType: Parser[DataType] =
+ "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+ case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
+ }
+
+ protected lazy val structField: Parser[StructField] =
+ ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
+ case name ~ tpe ~ nullable =>
+ StructField(name, tpe, nullable = nullable)
+ }
+
+ protected lazy val boolVal: Parser[Boolean] =
+ ( "true" ^^^ true
+ | "false" ^^^ false
+ )
+
+ protected lazy val structType: Parser[DataType] =
+ "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
+ case fields => StructType(fields)
+ }
+
+ protected lazy val dataType: Parser[DataType] =
+ ( arrayType
+ | mapType
+ | structType
+ | primitiveType
+ )
+
+ /**
+ * Parses a string representation of a DataType.
+ *
+ * TODO: Generate parser as pickler...
+ */
+ def apply(asString: String): DataType = parseAll(dataType, asString) match {
+ case Success(result, _) => result
+ case failure: NoSuccess =>
+ throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
+ }
+ }
+
+ protected[types] def buildFormattedString(
+ dataType: DataType,
+ prefix: String,
+ builder: StringBuilder): Unit = {
+ dataType match {
+ case array: ArrayType =>
+ array.buildFormattedString(prefix, builder)
+ case struct: StructType =>
+ struct.buildFormattedString(prefix, builder)
+ case map: MapType =>
+ map.buildFormattedString(prefix, builder)
+ case _ =>
+ }
+ }
+
+ /**
+ * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
+ */
+ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ (left, right) match {
+ case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
+ equalsIgnoreNullability(leftElementType, rightElementType)
+ case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
+ equalsIgnoreNullability(leftKeyType, rightKeyType) &&
+ equalsIgnoreNullability(leftValueType, rightValueType)
+ case (StructType(leftFields), StructType(rightFields)) =>
+ leftFields.length == rightFields.length &&
+ leftFields.zip(rightFields).forall { case (l, r) =>
+ l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
+ }
+ case (l, r) => l == r
+ }
+ }
+
+ /**
+ * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
+ *
+ * Compatible nullability is defined as follows:
+ * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
+ * if and only if `to.containsNull` is true, or both of `from.containsNull` and
+ * `to.containsNull` are false.
+ * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
+ * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
+ * `to.valueContainsNull` are false.
+ * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
+ * if and only if, for every pair of fields, `toField.nullable` is true, or both
+ * of `fromField.nullable` and `toField.nullable` are false.
+ */
+ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
+ (from, to) match {
+ case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
+ (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
+
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ (tn || !fn) &&
+ equalsIgnoreCompatibleNullability(fromKey, toKey) &&
+ equalsIgnoreCompatibleNullability(fromValue, toValue)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).forall { case (fromField, toField) =>
+ fromField.name == toField.name &&
+ (toField.nullable || !fromField.nullable) &&
+ equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
+ }
+
+ case (fromDataType, toDataType) => fromDataType == toDataType
+ }
+ }
+}
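A hedged round-trip sketch for the JSON serialization and parsing defined above (the nullability-comparison helpers are package-private, so only the public `json`/`fromJson` path is exercised):

```scala
import org.apache.spark.sql.types._

// A nested type: a map from string keys to arrays of non-null doubles.
val dt = MapType(StringType, ArrayType(DoubleType, containsNull = false), valueContainsNull = true)

// json/fromJson round-trip: the parsed type equals the original.
assert(DataType.fromJson(dt.json) == dt)

// Simple type names parse directly from their JSON string form.
assert(DataType.fromJson("\"integer\"") == IntegerType)

// Legacy case-class strings still parse through the deprecated fromCaseClassString path.
assert(DataType.fromCaseClassString("ArrayType(DoubleType,false)") ==
  ArrayType(DoubleType, containsNull = false))
```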
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
index 5163f05879e42..04f3379afb38d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala
@@ -108,7 +108,7 @@ private[sql] object DataTypeParser {
override val lexical = new SqlLexical
}
- def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
+ def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
}
/** The exception thrown from the [[DataTypeParser]]. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala
new file mode 100644
index 0000000000000..03f0644bc784c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `java.sql.Date` values.
+ * Please use the singleton [[DataTypes.DateType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class DateType private() extends AtomicType {
+ // The companion object and this class are separated so that the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "DateType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Int
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the DateType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
+
+ private[spark] override def asNullable: DateType = this
+}
+
+
+case object DateType extends DateType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
new file mode 100644
index 0000000000000..0f8cecd28f7df
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `java.math.BigDecimal` values.
+ * A Decimal that might have fixed precision and scale, or unlimited values for these.
+ *
+ * Please use [[DataTypes.createDecimalType()]] to create a specific instance.
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+
+ /** No-arg constructor for kryo. */
+ protected def this() = this(null)
+
+ private[sql] type InternalType = Decimal
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = Decimal.DecimalIsFractional
+ private[sql] val fractional = Decimal.DecimalIsFractional
+ private[sql] val ordering = Decimal.DecimalIsFractional
+ private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
+
+ def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
+
+ def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
+
+ override def typeName: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+ case None => "decimal"
+ }
+
+ override def toString: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
+ case None => "DecimalType()"
+ }
+
+ /**
+ * The default size of a value of the DecimalType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
+
+ override def simpleString: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+ case None => "decimal(10,0)"
+ }
+
+ private[spark] override def asNullable: DecimalType = this
+}
+
+
+/** Extra factory methods and pattern matchers for Decimals */
+object DecimalType {
+ val Unlimited: DecimalType = DecimalType(None)
+
+ object Fixed {
+ def unapply(t: DecimalType): Option[(Int, Int)] =
+ t.precisionInfo.map(p => (p.precision, p.scale))
+ }
+
+ object Expression {
+ def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
+ case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+ case _ => None
+ }
+ }
+
+ def apply(): DecimalType = Unlimited
+
+ def apply(precision: Int, scale: Int): DecimalType =
+ DecimalType(Some(PrecisionInfo(precision, scale)))
+
+ def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
+
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
+
+ def isFixed(dataType: DataType): Boolean = dataType match {
+ case DecimalType.Fixed(_, _) => true
+ case _ => false
+ }
+}
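
For illustration (not part of the patch): a REPL-style sketch of the factory methods and extractors defined above.

{{{
import org.apache.spark.sql.types._

val fixed = DecimalType(10, 2)         // DecimalType(Some(PrecisionInfo(10,2)))
val unlimited = DecimalType.Unlimited  // DecimalType(None)

assert(fixed.typeName == "decimal(10,2)")
// Unlimited decimals report -1 for both precision and scale.
assert(unlimited.precision == -1 && unlimited.scale == -1)

// The Fixed extractor only matches fixed-precision decimals.
fixed match {
  case DecimalType.Fixed(p, s) => println(s"precision=$p, scale=$s")
  case _ => println("unlimited")
}

assert(DecimalType.isFixed(fixed) && !DecimalType.isFixed(unlimited))
}}}
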
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
new file mode 100644
index 0000000000000..66766623213c9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.{Ordering, Fractional, Numeric}
+import scala.math.Numeric.DoubleAsIfIntegral
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class DoubleType private() extends FractionalType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Double
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Double]]
+ private[sql] val fractional = implicitly[Fractional[Double]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+ private[sql] val asIntegral = DoubleAsIfIntegral
+
+ /**
+ * The default size of a value of the DoubleType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
+
+ private[spark] override def asNullable: DoubleType = this
+}
+
+case object DoubleType extends DoubleType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
new file mode 100644
index 0000000000000..1d5a2f4f6f86c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Numeric.FloatAsIfIntegral
+import scala.math.{Ordering, Fractional, Numeric}
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class FloatType private() extends FractionalType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "FloatType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Float
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Float]]
+ private[sql] val fractional = implicitly[Fractional[Float]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+ private[sql] val asIntegral = FloatAsIfIntegral
+
+ /**
+ * The default size of a value of the FloatType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
+
+ private[spark] override def asNullable: FloatType = this
+}
+
+case object FloatType extends FloatType
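
For illustration (not part of the patch): because DoubleType and FloatType are case objects, they can be pattern matched directly, as in this small sketch.

{{{
import org.apache.spark.sql.types._

// The fractional singletons are plain case objects, so they work as patterns.
def isFloatingPoint(dt: DataType): Boolean = dt match {
  case FloatType | DoubleType => true
  case _ => false
}

assert(isFloatingPoint(DoubleType) && !isFloatingPoint(StringType))
assert(FloatType.defaultSize == 4 && DoubleType.defaultSize == 8)
}}}
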
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
new file mode 100644
index 0000000000000..74e464c082873
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.{Ordering, Integral, Numeric}
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class IntegerType private() extends IntegralType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Int
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Int]]
+ private[sql] val integral = implicitly[Integral[Int]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the IntegerType is 4 bytes.
+ */
+ override def defaultSize: Int = 4
+
+ override def simpleString: String = "int"
+
+ private[spark] override def asNullable: IntegerType = this
+}
+
+case object IntegerType extends IntegerType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
new file mode 100644
index 0000000000000..390675782e5fd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.{Ordering, Integral, Numeric}
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class LongType private() extends IntegralType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "LongType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Long
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Long]]
+ private[sql] val integral = implicitly[Integral[Long]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the LongType is 8 bytes.
+ */
+ override def defaultSize: Int = 8
+
+ override def simpleString: String = "bigint"
+
+ private[spark] override def asNullable: LongType = this
+}
+
+
+case object LongType extends LongType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
new file mode 100644
index 0000000000000..cfdf493074415
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type for Maps. Keys in a map are not allowed to be `null`.
+ *
+ * Please use [[DataTypes.createMapType()]] to create a specific instance.
+ *
+ * @param keyType The data type of map keys.
+ * @param valueType The data type of map values.
+ * @param valueContainsNull Indicates if map values have `null` values.
+ *
+ * @group dataType
+ */
+case class MapType(
+ keyType: DataType,
+ valueType: DataType,
+ valueContainsNull: Boolean) extends DataType {
+
+ /** No-arg constructor for kryo. */
+ def this() = this(null, null, false)
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(s"$prefix-- key: ${keyType.typeName}\n")
+ builder.append(s"$prefix-- value: ${valueType.typeName} " +
+ s"(valueContainsNull = $valueContainsNull)\n")
+ DataType.buildFormattedString(keyType, s"$prefix |", builder)
+ DataType.buildFormattedString(valueType, s"$prefix |", builder)
+ }
+
+ override private[sql] def jsonValue: JValue =
+ ("type" -> typeName) ~
+ ("keyType" -> keyType.jsonValue) ~
+ ("valueType" -> valueType.jsonValue) ~
+ ("valueContainsNull" -> valueContainsNull)
+
+ /**
+ * The default size of a value of the MapType is
+ * 100 * (the default size of the key type + the default size of the value type).
+ * (We assume that there are 100 elements).
+ */
+ override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
+
+ override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
+
+ private[spark] override def asNullable: MapType =
+ MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
+}
+
+
+object MapType {
+ /**
+ * Construct a [[MapType]] object with the given key type and value type.
+   * The `valueContainsNull` field is set to `true`.
+ */
+ def apply(keyType: DataType, valueType: DataType): MapType =
+    MapType(keyType, valueType, valueContainsNull = true)
+}
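
For illustration (not part of the patch): a short sketch of the two-argument factory and the size estimate described above.

{{{
import org.apache.spark.sql.types._

// The two-argument factory defaults valueContainsNull to true.
val m = MapType(StringType, IntegerType)
assert(m.valueContainsNull)
assert(m.simpleString == "map<string,int>")

// defaultSize assumes 100 entries: 100 * (key default size + value default size).
assert(m.defaultSize == 100 * (StringType.defaultSize + IntegerType.defaultSize))
}}}
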
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala
new file mode 100644
index 0000000000000..b64b07431fa96
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.annotation.DeveloperApi
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class NullType private() extends DataType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "NullType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ override def defaultSize: Int = 1
+
+ private[spark] override def asNullable: NullType = this
+}
+
+case object NullType extends NullType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
new file mode 100644
index 0000000000000..73e9ec780b0af
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.{Ordering, Integral, Numeric}
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class ShortType private() extends IntegralType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "ShortType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Short
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val numeric = implicitly[Numeric[Short]]
+ private[sql] val integral = implicitly[Integral[Short]]
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the ShortType is 2 bytes.
+ */
+ override def defaultSize: Int = 2
+
+ override def simpleString: String = "smallint"
+
+ private[spark] override def asNullable: ShortType = this
+}
+
+case object ShortType extends ShortType
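
For illustration (not part of the patch): the integral singletons expose SQL-style names via simpleString and the widths of the underlying JVM primitives via defaultSize.

{{{
import org.apache.spark.sql.types._

// simpleString is what appears inside struct<...> renderings (see StructType below).
assert(ShortType.simpleString == "smallint")
assert(IntegerType.simpleString == "int")
assert(LongType.simpleString == "bigint")

// defaultSize mirrors the widths of the underlying JVM primitives.
assert(ShortType.defaultSize == 2 && IntegerType.defaultSize == 4 && LongType.defaultSize == 8)
}}}
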
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
new file mode 100644
index 0000000000000..134ab0af4e0de
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class StringType private() extends AtomicType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "StringType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = UTF8String
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+ private[sql] val ordering = implicitly[Ordering[InternalType]]
+
+ /**
+ * The default size of a value of the StringType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
+
+ private[spark] override def asNullable: StringType = this
+}
+
+case object StringType extends StringType
+
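
For illustration (not part of the patch): a sketch of the JSON round trip for a simple type, assuming the split-out DataType base class keeps the json/fromJson helpers visible in the removed dataTypes.scala further down.

{{{
import org.apache.spark.sql.types._

// typeName drops the "Type" suffix and lower-cases the rest, so StringType
// serializes to the bare JSON string "string".
assert(StringType.typeName == "string")
assert(StringType.json == "\"string\"")
assert(DataType.fromJson(StringType.json) == StringType)
}}}
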
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala
new file mode 100644
index 0000000000000..83570a5eaee61
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+
+/**
+ * A field inside a StructType.
+ * @param name The name of this field.
+ * @param dataType The data type of this field.
+ * @param nullable Indicates if values of this field can be `null` values.
+ * @param metadata The metadata of this field. The metadata should be preserved during
+ *                 transformation if the content of the column is not modified, e.g., in selection.
+ */
+case class StructField(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean = true,
+ metadata: Metadata = Metadata.empty) {
+
+ /** No-arg constructor for kryo. */
+ protected def this() = this(null, null)
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
+ DataType.buildFormattedString(dataType, s"$prefix |", builder)
+ }
+
+  // Override the default toString to be compatible with legacy Parquet files.
+ override def toString: String = s"StructField($name,$dataType,$nullable)"
+
+ private[sql] def jsonValue: JValue = {
+ ("name" -> name) ~
+ ("type" -> dataType.jsonValue) ~
+ ("nullable" -> nullable) ~
+ ("metadata" -> metadata.jsonValue)
+ }
+}
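
For illustration (not part of the patch): nullable and metadata both have defaults, and toString deliberately omits metadata, as this small sketch shows.

{{{
import org.apache.spark.sql.types._

val f = StructField("age", IntegerType)
assert(f.nullable)
assert(f.metadata == Metadata.empty)

// toString keeps the legacy three-field format for Parquet compatibility.
assert(f.toString == "StructField(age,IntegerType,true)")
}}}
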
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
new file mode 100644
index 0000000000000..d80ffca18ec9a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -0,0 +1,263 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.collection.mutable.ArrayBuffer
+import scala.math.max
+
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
+
+
+/**
+ * :: DeveloperApi ::
+ * A [[StructType]] object can be constructed by
+ * {{{
+ * StructType(fields: Seq[StructField])
+ * }}}
+ * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names.
+ * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned.
+ * If a provided name does not have a matching field, an `IllegalArgumentException` is thrown;
+ * this applies both when extracting a single field and when extracting multiple fields.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val struct =
+ * StructType(
+ * StructField("a", IntegerType, true) ::
+ * StructField("b", LongType, false) ::
+ * StructField("c", BooleanType, false) :: Nil)
+ *
+ * // Extract a single StructField.
+ * val singleField = struct("b")
+ * // singleField: StructField = StructField(b,LongType,false)
+ *
+ * // This struct does not have a field called "d". Asking for it throws an exception.
+ * // struct("d")
+ * // java.lang.IllegalArgumentException: Field "d" does not exist.
+ *
+ * // Extract multiple StructFields. Field names are provided in a set.
+ * // A StructType object will be returned.
+ * val twoFields = struct(Set("b", "c"))
+ * // twoFields: StructType =
+ * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false)))
+ *
+ * // Any names without matching fields also trigger an exception.
+ * // struct(Set("b", "c", "d"))
+ * // java.lang.IllegalArgumentException: Field d does not exist.
+ * }}}
+ *
+ * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val innerStruct =
+ * StructType(
+ * StructField("f1", IntegerType, true) ::
+ * StructField("f2", LongType, false) ::
+ * StructField("f3", BooleanType, false) :: Nil)
+ *
+ * val struct = StructType(
+ * StructField("a", innerStruct, true) :: Nil)
+ *
+ * // Create a Row with the schema defined by struct
+ * val row = Row(Row(1, 2, true))
+ * // row: Row = [[1,2,true]]
+ * }}}
+ *
+ * @group dataType
+ */
+@DeveloperApi
+case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
+
+ /** No-arg constructor for kryo. */
+ protected def this() = this(null)
+
+ /** Returns all field names in an array. */
+ def fieldNames: Array[String] = fields.map(_.name)
+
+ private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
+ private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+ private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
+
+ /**
+   * Extracts the [[StructField]] with the given name. If this [[StructType]] does not have a
+   * field with a matching name, an `IllegalArgumentException` is thrown.
+ */
+ def apply(name: String): StructField = {
+ nameToField.getOrElse(name,
+ throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ }
+
+ /**
+ * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the
+   * original order of fields. Names without matching fields cause an `IllegalArgumentException`
+   * to be thrown.
+ */
+ def apply(names: Set[String]): StructType = {
+ val nonExistFields = names -- fieldNamesSet
+ if (nonExistFields.nonEmpty) {
+ throw new IllegalArgumentException(
+ s"Field ${nonExistFields.mkString(",")} does not exist.")
+ }
+ // Preserve the original order of fields.
+ StructType(fields.filter(f => names.contains(f.name)))
+ }
+
+ /**
+   * Returns the index of the field with the given name.
+ */
+ def fieldIndex(name: String): Int = {
+ nameToIndex.getOrElse(name,
+ throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
+ }
+
+ protected[sql] def toAttributes: Seq[AttributeReference] =
+ map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+
+ def treeString: String = {
+ val builder = new StringBuilder
+ builder.append("root\n")
+ val prefix = " |"
+ fields.foreach(field => field.buildFormattedString(prefix, builder))
+
+ builder.toString()
+ }
+
+ def printTreeString(): Unit = println(treeString)
+
+ private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
+ fields.foreach(field => field.buildFormattedString(prefix, builder))
+ }
+
+ override private[sql] def jsonValue =
+ ("type" -> typeName) ~
+ ("fields" -> map(_.jsonValue))
+
+ override def apply(fieldIndex: Int): StructField = fields(fieldIndex)
+
+ override def length: Int = fields.length
+
+ override def iterator: Iterator[StructField] = fields.iterator
+
+ /**
+ * The default size of a value of the StructType is the total default sizes of all field types.
+ */
+ override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
+
+ override def simpleString: String = {
+ val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}")
+ s"struct<${fieldTypes.mkString(",")}>"
+ }
+
+ /**
+ * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field
+ * B from `that`,
+ *
+ * 1. If A and B have the same name and data type, they are merged to a field C with the same name
+ * and data type. C is nullable if and only if either A or B is nullable.
+ * 2. If A doesn't exist in `that`, it's included in the result schema.
+ * 3. If B doesn't exist in `this`, it's also included in the result schema.
+   * 4. Otherwise, `this` and `that` are considered conflicting schemas and an exception is
+   *    thrown.
+ */
+ private[sql] def merge(that: StructType): StructType =
+ StructType.merge(this, that).asInstanceOf[StructType]
+
+ private[spark] override def asNullable: StructType = {
+ val newFields = fields.map {
+ case StructField(name, dataType, nullable, metadata) =>
+ StructField(name, dataType.asNullable, nullable = true, metadata)
+ }
+
+ StructType(newFields)
+ }
+}
+
+
+object StructType {
+
+ def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
+
+ def apply(fields: java.util.List[StructField]): StructType = {
+ StructType(fields.toArray.asInstanceOf[Array[StructField]])
+ }
+
+ protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
+ StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+ private[sql] def merge(left: DataType, right: DataType): DataType =
+ (left, right) match {
+ case (ArrayType(leftElementType, leftContainsNull),
+ ArrayType(rightElementType, rightContainsNull)) =>
+ ArrayType(
+ merge(leftElementType, rightElementType),
+ leftContainsNull || rightContainsNull)
+
+ case (MapType(leftKeyType, leftValueType, leftContainsNull),
+ MapType(rightKeyType, rightValueType, rightContainsNull)) =>
+ MapType(
+ merge(leftKeyType, rightKeyType),
+ merge(leftValueType, rightValueType),
+ leftContainsNull || rightContainsNull)
+
+ case (StructType(leftFields), StructType(rightFields)) =>
+ val newFields = ArrayBuffer.empty[StructField]
+
+ leftFields.foreach {
+ case leftField @ StructField(leftName, leftType, leftNullable, _) =>
+ rightFields
+ .find(_.name == leftName)
+ .map { case rightField @ StructField(_, rightType, rightNullable, _) =>
+ leftField.copy(
+ dataType = merge(leftType, rightType),
+ nullable = leftNullable || rightNullable)
+ }
+ .orElse(Some(leftField))
+ .foreach(newFields += _)
+ }
+
+ rightFields
+ .filterNot(f => leftFields.map(_.name).contains(f.name))
+ .foreach(newFields += _)
+
+ StructType(newFields)
+
+ case (DecimalType.Fixed(leftPrecision, leftScale),
+ DecimalType.Fixed(rightPrecision, rightScale)) =>
+ DecimalType(
+ max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale),
+ max(leftScale, rightScale))
+
+ case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
+ if leftUdt.userClass == rightUdt.userClass => leftUdt
+
+ case (leftType, rightType) if leftType == rightType =>
+ leftType
+
+ case _ =>
+ throw new SparkException(s"Failed to merge incompatible data types $left and $right")
+ }
+}
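
For illustration (not part of the patch): a REPL-style sketch of field extraction on a StructType, which is itself a Seq[StructField].

{{{
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType)))

// Lookup by name, by position, and name -> position.
assert(schema("id") == StructField("id", LongType, nullable = false))
assert(schema(1).name == "name")
assert(schema.fieldIndex("name") == 1)

// The usual Seq operations are available directly on the schema.
assert(schema.map(_.name) == Seq("id", "name"))

assert(schema.simpleString == "struct<id:bigint,name:string>")
}}}
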
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
new file mode 100644
index 0000000000000..aebabfc475925
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import java.sql.Timestamp
+
+import scala.math.Ordering
+import scala.reflect.runtime.universe.typeTag
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+
+
+/**
+ * :: DeveloperApi ::
+ * The data type representing `java.sql.Timestamp` values.
+ * Please use the singleton [[DataTypes.TimestampType]].
+ *
+ * @group dataType
+ */
+@DeveloperApi
+class TimestampType private() extends AtomicType {
+  // The companion object and this class are separated so the companion object also subclasses
+ // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code.
+ // Defined with a private constructor so the companion object is the only possible instantiation.
+ private[sql] type InternalType = Timestamp
+
+ @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
+
+ private[sql] val ordering = new Ordering[InternalType] {
+ def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y)
+ }
+
+ /**
+ * The default size of a value of the TimestampType is 12 bytes.
+ */
+ override def defaultSize: Int = 12
+
+ private[spark] override def asNullable: TimestampType = this
+}
+
+case object TimestampType extends TimestampType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
new file mode 100644
index 0000000000000..6b20505c6009a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL._
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * The data type for User Defined Types (UDTs).
+ *
+ * This interface allows a user to make their own classes more interoperable with Spark SQL;
+ * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
+ * a `DataFrame` which has class X in the schema.
+ *
+ * For Spark SQL to recognize UDTs, the UDT must be annotated with
+ * [[SQLUserDefinedType]].
+ *
+ * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
+ * The conversion via `deserialize` occurs when reading from a `DataFrame`.
+ */
+@DeveloperApi
+abstract class UserDefinedType[UserType] extends DataType with Serializable {
+
+ /** Underlying storage type for this UDT */
+ def sqlType: DataType
+
+  /** Paired Python UDT class, if one exists. */
+ def pyUDT: String = null
+
+ /**
+ * Convert the user type to a SQL datum
+ *
+ * TODO: Can we make this take obj: UserType? The issue is in
+ * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
+ */
+ def serialize(obj: Any): Any
+
+ /** Convert a SQL datum to the user type */
+ def deserialize(datum: Any): UserType
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("class" -> this.getClass.getName) ~
+ ("pyClass" -> pyUDT) ~
+ ("sqlType" -> sqlType.jsonValue)
+ }
+
+ /**
+ * Class object for the UserType
+ */
+ def userClass: java.lang.Class[UserType]
+
+ /**
+ * The default size of a value of the UserDefinedType is 4096 bytes.
+ */
+ override def defaultSize: Int = 4096
+
+ /**
+   * For a UDT, asNullable does not change the nullability of its internal sqlType; it simply
+   * returns the UDT itself.
+ */
+ private[spark] override def asNullable: UserDefinedType[UserType] = this
+}
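
For illustration (not part of the patch): a hypothetical minimal UDT for a 2-D point. The Point2D/Point2DUDT names are invented, and the Seq[Double] used as the serialized form is an assumption; the exact catalyst representation expected for an ArrayType column is an internal detail.

{{{
import org.apache.spark.sql.types._

// Hypothetical user class, annotated so Spark SQL can find its UDT.
@SQLUserDefinedType(udt = classOf[Point2DUDT])
class Point2D(val x: Double, val y: Double) extends Serializable

class Point2DUDT extends UserDefinedType[Point2D] {

  // Underlying storage: a fixed pair of doubles.
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // Assumption: a Seq[Double] is accepted as the catalyst value for this sqlType.
  override def serialize(obj: Any): Any = obj match {
    case p: Point2D => Seq(p.x, p.y)
  }

  override def deserialize(datum: Any): Point2D = datum match {
    case Seq(x: Double, y: Double) => new Point2D(x, y)
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}
}}}
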
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
deleted file mode 100644
index 7cd7bd1914c95..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala
+++ /dev/null
@@ -1,1238 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.types
-
-import java.sql.Timestamp
-
-import scala.collection.mutable.ArrayBuffer
-import scala.math._
-import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral}
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
-import scala.util.parsing.combinator.RegexParsers
-
-import org.json4s._
-import org.json4s.JsonAST.JValue
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
-
-import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
-import org.apache.spark.util.Utils
-
-
-object DataType {
- def fromJson(json: String): DataType = parseDataType(parse(json))
-
- private object JSortedObject {
- def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match {
- case JObject(seq) => Some(seq.toList.sortBy(_._1))
- case _ => None
- }
- }
-
- // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side.
- private def parseDataType(json: JValue): DataType = json match {
- case JString(name) =>
- PrimitiveType.nameToType(name)
-
- case JSortedObject(
- ("containsNull", JBool(n)),
- ("elementType", t: JValue),
- ("type", JString("array"))) =>
- ArrayType(parseDataType(t), n)
-
- case JSortedObject(
- ("keyType", k: JValue),
- ("type", JString("map")),
- ("valueContainsNull", JBool(n)),
- ("valueType", v: JValue)) =>
- MapType(parseDataType(k), parseDataType(v), n)
-
- case JSortedObject(
- ("fields", JArray(fields)),
- ("type", JString("struct"))) =>
- StructType(fields.map(parseStructField))
-
- case JSortedObject(
- ("class", JString(udtClass)),
- ("pyClass", _),
- ("sqlType", _),
- ("type", JString("udt"))) =>
- Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
- }
-
- private def parseStructField(json: JValue): StructField = json match {
- case JSortedObject(
- ("metadata", metadata: JObject),
- ("name", JString(name)),
- ("nullable", JBool(nullable)),
- ("type", dataType: JValue)) =>
- StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata))
- // Support reading schema when 'metadata' is missing.
- case JSortedObject(
- ("name", JString(name)),
- ("nullable", JBool(nullable)),
- ("type", dataType: JValue)) =>
- StructField(name, parseDataType(dataType), nullable)
- }
-
- @deprecated("Use DataType.fromJson instead", "1.2.0")
- def fromCaseClassString(string: String): DataType = CaseClassStringParser(string)
-
- private object CaseClassStringParser extends RegexParsers {
- protected lazy val primitiveType: Parser[DataType] =
- ( "StringType" ^^^ StringType
- | "FloatType" ^^^ FloatType
- | "IntegerType" ^^^ IntegerType
- | "ByteType" ^^^ ByteType
- | "ShortType" ^^^ ShortType
- | "DoubleType" ^^^ DoubleType
- | "LongType" ^^^ LongType
- | "BinaryType" ^^^ BinaryType
- | "BooleanType" ^^^ BooleanType
- | "DateType" ^^^ DateType
- | "DecimalType()" ^^^ DecimalType.Unlimited
- | fixedDecimalType
- | "TimestampType" ^^^ TimestampType
- )
-
- protected lazy val fixedDecimalType: Parser[DataType] =
- ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
- case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
- }
-
- protected lazy val arrayType: Parser[DataType] =
- "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
- case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
- }
-
- protected lazy val mapType: Parser[DataType] =
- "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
- case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
- }
-
- protected lazy val structField: Parser[StructField] =
- ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
- case name ~ tpe ~ nullable =>
- StructField(name, tpe, nullable = nullable)
- }
-
- protected lazy val boolVal: Parser[Boolean] =
- ( "true" ^^^ true
- | "false" ^^^ false
- )
-
- protected lazy val structType: Parser[DataType] =
- "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
- case fields => StructType(fields)
- }
-
- protected lazy val dataType: Parser[DataType] =
- ( arrayType
- | mapType
- | structType
- | primitiveType
- )
-
- /**
- * Parses a string representation of a DataType.
- *
- * TODO: Generate parser as pickler...
- */
- def apply(asString: String): DataType = parseAll(dataType, asString) match {
- case Success(result, _) => result
- case failure: NoSuccess =>
- throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure")
- }
- }
-
- protected[types] def buildFormattedString(
- dataType: DataType,
- prefix: String,
- builder: StringBuilder): Unit = {
- dataType match {
- case array: ArrayType =>
- array.buildFormattedString(prefix, builder)
- case struct: StructType =>
- struct.buildFormattedString(prefix, builder)
- case map: MapType =>
- map.buildFormattedString(prefix, builder)
- case _ =>
- }
- }
-
- /**
- * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
- */
- private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
- (left, right) match {
- case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
- equalsIgnoreNullability(leftElementType, rightElementType)
- case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
- equalsIgnoreNullability(leftKeyType, rightKeyType) &&
- equalsIgnoreNullability(leftValueType, rightValueType)
- case (StructType(leftFields), StructType(rightFields)) =>
- leftFields.size == rightFields.size &&
- leftFields.zip(rightFields)
- .forall{
- case (left, right) =>
- left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
- }
- case (left, right) => left == right
- }
- }
-
- /**
- * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
- *
- * Compatible nullability is defined as follows:
- * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to`
- * if and only if `to.containsNull` is true, or both of `from.containsNull` and
- * `to.containsNull` are false.
- * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to`
- * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and
- * `to.valueContainsNull` are false.
- * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to`
- * if and only if for all every pair of fields, `to.nullable` is true, or both
- * of `fromField.nullable` and `toField.nullable` are false.
- */
- private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
- (from, to) match {
- case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
- (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
-
- case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
- (tn || !fn) &&
- equalsIgnoreCompatibleNullability(fromKey, toKey) &&
- equalsIgnoreCompatibleNullability(fromValue, toValue)
-
- case (StructType(fromFields), StructType(toFields)) =>
- fromFields.size == toFields.size &&
- fromFields.zip(toFields).forall {
- case (fromField, toField) =>
- fromField.name == toField.name &&
- (toField.nullable || !fromField.nullable) &&
- equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType)
- }
-
- case (fromDataType, toDataType) => fromDataType == toDataType
- }
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * The base type of all Spark SQL data types.
- *
- * @group dataType
- */
-@DeveloperApi
-abstract class DataType {
- /** Matches any expression that evaluates to this DataType */
- def unapply(a: Expression): Boolean = a match {
- case e: Expression if e.dataType == this => true
- case _ => false
- }
-
- /** The default size of a value of this data type. */
- def defaultSize: Int
-
- def isPrimitive: Boolean = false
-
- def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase
-
- private[sql] def jsonValue: JValue = typeName
-
- def json: String = compact(render(jsonValue))
-
- def prettyJson: String = pretty(render(jsonValue))
-
- def simpleString: String = typeName
-
- /** Check if `this` and `other` are the same data type when ignoring nullability
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
- */
- private[spark] def sameType(other: DataType): Boolean =
- DataType.equalsIgnoreNullability(this, other)
-
- /** Returns the same data type but set all nullability fields are true
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
- */
- private[spark] def asNullable: DataType
-}
-
-/**
- * :: DeveloperApi ::
- * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class NullType private() extends DataType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "NullType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- override def defaultSize: Int = 1
-
- private[spark] override def asNullable: NullType = this
-}
-
-case object NullType extends NullType
-
-
-protected[spark] object NativeType {
- val all = Seq(
- IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
-
- def unapply(dt: DataType): Boolean = all.contains(dt)
-}
-
-
-protected[sql] trait PrimitiveType extends DataType {
- override def isPrimitive: Boolean = true
-}
-
-
-protected[sql] object PrimitiveType {
- private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all
- private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
-
- /** Given the string representation of a type, return its DataType */
- private[sql] def nameToType(name: String): DataType = {
- val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
- name match {
- case "decimal" => DecimalType.Unlimited
- case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
- case other => nonDecimalNameToType(other)
- }
- }
-}
-
-protected[spark] abstract class NativeType extends DataType {
- private[sql] type JvmType
- @transient private[sql] val tag: TypeTag[JvmType]
- private[sql] val ordering: Ordering[JvmType]
-
- @transient private[sql] val classTag = ScalaReflectionLock.synchronized {
- val mirror = runtimeMirror(Utils.getSparkClassLoader)
- ClassTag[JvmType](mirror.runtimeClass(tag.tpe))
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class StringType private() extends NativeType with PrimitiveType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "StringType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = UTF8String
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the StringType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-
- private[spark] override def asNullable: StringType = this
-}
-
-case object StringType extends StringType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Array[Byte]` values.
- * Please use the singleton [[DataTypes.BinaryType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class BinaryType private() extends NativeType with PrimitiveType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Array[Byte]
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val ordering = new Ordering[JvmType] {
- def compare(x: Array[Byte], y: Array[Byte]): Int = {
- for (i <- 0 until x.length; if i < y.length) {
- val res = x(i).compareTo(y(i))
- if (res != 0) return res
- }
- x.length - y.length
- }
- }
-
- /**
- * The default size of a value of the BinaryType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-
- private[spark] override def asNullable: BinaryType = this
-}
-
-case object BinaryType extends BinaryType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]].
- *
- *@group dataType
- */
-@DeveloperApi
-class BooleanType private() extends NativeType with PrimitiveType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Boolean
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the BooleanType is 1 byte.
- */
- override def defaultSize: Int = 1
-
- private[spark] override def asNullable: BooleanType = this
-}
-
-case object BooleanType extends BooleanType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `java.sql.Timestamp` values.
- * Please use the singleton [[DataTypes.TimestampType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class TimestampType private() extends NativeType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Timestamp
-
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
-
- private[sql] val ordering = new Ordering[JvmType] {
- def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y)
- }
-
- /**
- * The default size of a value of the TimestampType is 12 bytes.
- */
- override def defaultSize: Int = 12
-
- private[spark] override def asNullable: TimestampType = this
-}
-
-case object TimestampType extends TimestampType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `java.sql.Date` values.
- * Please use the singleton [[DataTypes.DateType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class DateType private() extends NativeType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "DateType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Int
-
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
-
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the DateType is 4 bytes.
- */
- override def defaultSize: Int = 4
-
- private[spark] override def asNullable: DateType = this
-}
-
-case object DateType extends DateType
-
-
-/**
- * :: DeveloperApi ::
- * Numeric data types.
- *
- * @group dataType
- */
-abstract class NumericType extends NativeType with PrimitiveType {
- // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
- // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
- // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets
- // desugared by the compiler into an argument to the object's constructor. This means there is
- // no longer a no-argument constructor and thus the JVM cannot serialize the object anymore.
- private[sql] val numeric: Numeric[JvmType]
-}
-
-
-protected[sql] object NumericType {
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-}
-
-
-/** Matcher for any expressions that evaluate to [[IntegralType]]s */
-protected[sql] object IntegralType {
- def unapply(a: Expression): Boolean = a match {
- case e: Expression if e.dataType.isInstanceOf[IntegralType] => true
- case _ => false
- }
-}
-
-
-protected[sql] sealed abstract class IntegralType extends NumericType {
- private[sql] val integral: Integral[JvmType]
-}
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class LongType private() extends IntegralType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "LongType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Long
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Long]]
- private[sql] val integral = implicitly[Integral[Long]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the LongType is 8 bytes.
- */
- override def defaultSize: Int = 8
-
- override def simpleString: String = "bigint"
-
- private[spark] override def asNullable: LongType = this
-}
-
-case object LongType extends LongType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class IntegerType private() extends IntegralType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Int
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Int]]
- private[sql] val integral = implicitly[Integral[Int]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the IntegerType is 4 bytes.
- */
- override def defaultSize: Int = 4
-
- override def simpleString: String = "int"
-
- private[spark] override def asNullable: IntegerType = this
-}
-
-case object IntegerType extends IntegerType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class ShortType private() extends IntegralType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "ShortType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Short
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Short]]
- private[sql] val integral = implicitly[Integral[Short]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the ShortType is 2 bytes.
- */
- override def defaultSize: Int = 2
-
- override def simpleString: String = "smallint"
-
- private[spark] override def asNullable: ShortType = this
-}
-
-case object ShortType extends ShortType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class ByteType private() extends IntegralType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "ByteType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Byte
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Byte]]
- private[sql] val integral = implicitly[Integral[Byte]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
-
- /**
- * The default size of a value of the ByteType is 1 byte.
- */
- override def defaultSize: Int = 1
-
- override def simpleString: String = "tinyint"
-
- private[spark] override def asNullable: ByteType = this
-}
-
-case object ByteType extends ByteType
-
-
-/** Matcher for any expressions that evaluate to [[FractionalType]]s */
-protected[sql] object FractionalType {
- def unapply(a: Expression): Boolean = a match {
- case e: Expression if e.dataType.isInstanceOf[FractionalType] => true
- case _ => false
- }
-}
-
-
-protected[sql] sealed abstract class FractionalType extends NumericType {
- private[sql] val fractional: Fractional[JvmType]
- private[sql] val asIntegral: Integral[JvmType]
-}
-
-
-/** Precision parameters for a Decimal */
-case class PrecisionInfo(precision: Int, scale: Int)
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `java.math.BigDecimal` values.
- * A Decimal that might have fixed precision and scale, or unlimited values for these.
- *
- * Please use [[DataTypes.createDecimalType()]] to create a specific instance.
- *
- * @group dataType
- */
-@DeveloperApi
-case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
-
- /** No-arg constructor for kryo. */
- protected def this() = this(null)
-
- private[sql] type JvmType = Decimal
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = Decimal.DecimalIsFractional
- private[sql] val fractional = Decimal.DecimalIsFractional
- private[sql] val ordering = Decimal.DecimalIsFractional
- private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
-
- def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
-
- def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
-
- override def typeName: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
- case None => "decimal"
- }
-
- override def toString: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
- case None => "DecimalType()"
- }
-
- /**
- * The default size of a value of the DecimalType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-
- override def simpleString: String = precisionInfo match {
- case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
- case None => "decimal(10,0)"
- }
-
- private[spark] override def asNullable: DecimalType = this
-}
-
-
-/** Extra factory methods and pattern matchers for Decimals */
-object DecimalType {
- val Unlimited: DecimalType = DecimalType(None)
-
- object Fixed {
- def unapply(t: DecimalType): Option[(Int, Int)] =
- t.precisionInfo.map(p => (p.precision, p.scale))
- }
-
- object Expression {
- def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
- case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
- case _ => None
- }
- }
-
- def apply(): DecimalType = Unlimited
-
- def apply(precision: Int, scale: Int): DecimalType =
- DecimalType(Some(PrecisionInfo(precision, scale)))
-
- def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
-
- def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
-
- def isFixed(dataType: DataType): Boolean = dataType match {
- case DecimalType.Fixed(_, _) => true
- case _ => false
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class DoubleType private() extends FractionalType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Double
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Double]]
- private[sql] val fractional = implicitly[Fractional[Double]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
- private[sql] val asIntegral = DoubleAsIfIntegral
-
- /**
- * The default size of a value of the DoubleType is 8 bytes.
- */
- override def defaultSize: Int = 8
-
- private[spark] override def asNullable: DoubleType = this
-}
-
-case object DoubleType extends DoubleType
-
-
-/**
- * :: DeveloperApi ::
- * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]].
- *
- * @group dataType
- */
-@DeveloperApi
-class FloatType private() extends FractionalType {
- // The companion object and this class is separated so the companion object also subclasses
- // this type. Otherwise, the companion object would be of type "FloatType$" in byte code.
- // Defined with a private constructor so the companion object is the only possible instantiation.
- private[sql] type JvmType = Float
- @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[Float]]
- private[sql] val fractional = implicitly[Fractional[Float]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
- private[sql] val asIntegral = FloatAsIfIntegral
-
- /**
- * The default size of a value of the FloatType is 4 bytes.
- */
- override def defaultSize: Int = 4
-
- private[spark] override def asNullable: FloatType = this
-}
-
-case object FloatType extends FloatType
-
-
-object ArrayType {
- /** Construct an [[ArrayType]] object with the given element type; `containsNull` is set to true. */
- def apply(elementType: DataType): ArrayType = ArrayType(elementType, true)
-}
-
-
-/**
- * :: DeveloperApi ::
- * The data type for collections of multiple values.
- * Internally these are represented as columns that contain a ``scala.collection.Seq``.
- *
- * Please use [[DataTypes.createArrayType()]] to create a specific instance.
- *
- * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and
- * `containsNull: Boolean`. The field of `elementType` is used to specify the type of
- * array elements. The field of `containsNull` is used to specify if the array has `null` values.
- *
- * @param elementType The data type of values.
- * @param containsNull Indicates if the array can contain `null` values
- *
- * @group dataType
- */
-@DeveloperApi
-case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
-
- /** No-arg constructor for kryo. */
- protected def this() = this(null, false)
-
- private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(
- s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n")
- DataType.buildFormattedString(elementType, s"$prefix |", builder)
- }
-
- override private[sql] def jsonValue =
- ("type" -> typeName) ~
- ("elementType" -> elementType.jsonValue) ~
- ("containsNull" -> containsNull)
-
- /**
- * The default size of a value of the ArrayType is 100 * the default size of the element type.
- * (We assume that there are 100 elements).
- */
- override def defaultSize: Int = 100 * elementType.defaultSize
-
- override def simpleString: String = s"array<${elementType.simpleString}>"
-
- private[spark] override def asNullable: ArrayType =
- ArrayType(elementType.asNullable, containsNull = true)
-}
-
-
-/**
- * A field inside a StructType.
- * @param name The name of this field.
- * @param dataType The data type of this field.
- * @param nullable Indicates if values of this field can be `null` values.
- * @param metadata The metadata of this field. The metadata should be preserved during
- * transformation if the content of the column is not modified, e.g., in selection.
- */
-case class StructField(
- name: String,
- dataType: DataType,
- nullable: Boolean = true,
- metadata: Metadata = Metadata.empty) {
-
- /** No-arg constructor for kryo. */
- protected def this() = this(null, null)
-
- private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n")
- DataType.buildFormattedString(dataType, s"$prefix |", builder)
- }
-
- // override the default toString to be compatible with legacy parquet files.
- override def toString: String = s"StructField($name,$dataType,$nullable)"
-
- private[sql] def jsonValue: JValue = {
- ("name" -> name) ~
- ("type" -> dataType.jsonValue) ~
- ("nullable" -> nullable) ~
- ("metadata" -> metadata.jsonValue)
- }
-}
-
-
-object StructType {
- protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
- StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
-
- def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
-
- def apply(fields: java.util.List[StructField]): StructType = {
- StructType(fields.toArray.asInstanceOf[Array[StructField]])
- }
-
- private[sql] def merge(left: DataType, right: DataType): DataType =
- (left, right) match {
- case (ArrayType(leftElementType, leftContainsNull),
- ArrayType(rightElementType, rightContainsNull)) =>
- ArrayType(
- merge(leftElementType, rightElementType),
- leftContainsNull || rightContainsNull)
-
- case (MapType(leftKeyType, leftValueType, leftContainsNull),
- MapType(rightKeyType, rightValueType, rightContainsNull)) =>
- MapType(
- merge(leftKeyType, rightKeyType),
- merge(leftValueType, rightValueType),
- leftContainsNull || rightContainsNull)
-
- case (StructType(leftFields), StructType(rightFields)) =>
- val newFields = ArrayBuffer.empty[StructField]
-
- leftFields.foreach {
- case leftField @ StructField(leftName, leftType, leftNullable, _) =>
- rightFields
- .find(_.name == leftName)
- .map { case rightField @ StructField(_, rightType, rightNullable, _) =>
- leftField.copy(
- dataType = merge(leftType, rightType),
- nullable = leftNullable || rightNullable)
- }
- .orElse(Some(leftField))
- .foreach(newFields += _)
- }
-
- rightFields
- .filterNot(f => leftFields.map(_.name).contains(f.name))
- .foreach(newFields += _)
-
- StructType(newFields)
-
- case (DecimalType.Fixed(leftPrecision, leftScale),
- DecimalType.Fixed(rightPrecision, rightScale)) =>
- DecimalType(
- max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale),
- max(leftScale, rightScale))
-
- case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_])
- if leftUdt.userClass == rightUdt.userClass => leftUdt
-
- case (leftType, rightType) if leftType == rightType =>
- leftType
-
- case _ =>
- throw new SparkException(s"Failed to merge incompatible data types $left and $right")
- }
-}
-
-
-/**
- * :: DeveloperApi ::
- * A [[StructType]] object can be constructed by
- * {{{
- * StructType(fields: Seq[StructField])
- * }}}
- * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names.
- * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned.
- * If a provided name does not have a matching field, it will be ignored. For the case
- * of extracting a single StructField, a `null` will be returned.
- * Example:
- * {{{
- * import org.apache.spark.sql._
- *
- * val struct =
- * StructType(
- * StructField("a", IntegerType, true) ::
- * StructField("b", LongType, false) ::
- * StructField("c", BooleanType, false) :: Nil)
- *
- * // Extract a single StructField.
- * val singleField = struct("b")
- * // singleField: StructField = StructField(b,LongType,false)
- *
- * // This struct does not have a field called "d". null will be returned.
- * val nonExisting = struct("d")
- * // nonExisting: StructField = null
- *
- * // Extract multiple StructFields. Field names are provided in a set.
- * // A StructType object will be returned.
- * val twoFields = struct(Set("b", "c"))
- * // twoFields: StructType =
- * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false)))
- *
- * // Any names without matching fields will be ignored.
- * // For the case shown below, "d" will be ignored and
- * // it is treated as struct(Set("b", "c")).
- * val ignoreNonExisting = struct(Set("b", "c", "d"))
- * // ignoreNonExisting: StructType =
- * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false)))
- * }}}
- *
- * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType.
- * Example:
- * {{{
- * import org.apache.spark.sql._
- *
- * val innerStruct =
- * StructType(
- * StructField("f1", IntegerType, true) ::
- * StructField("f2", LongType, false) ::
- * StructField("f3", BooleanType, false) :: Nil)
- *
- * val struct = StructType(
- * StructField("a", innerStruct, true) :: Nil)
- *
- * // Create a Row with the schema defined by struct
- * val row = Row(Row(1, 2, true))
- * // row: Row = [[1,2,true]]
- * }}}
- *
- * @group dataType
- */
-@DeveloperApi
-case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
-
- /** No-arg constructor for kryo. */
- protected def this() = this(null)
-
- /** Returns all field names in an array. */
- def fieldNames: Array[String] = fields.map(_.name)
-
- private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
- private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
- private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
-
- /**
- * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
- * have a name matching the given name, `null` will be returned.
- */
- def apply(name: String): StructField = {
- nameToField.getOrElse(name,
- throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
- }
-
- /**
- * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the
- * original order of fields. Those names which do not have matching fields will be ignored.
- */
- def apply(names: Set[String]): StructType = {
- val nonExistFields = names -- fieldNamesSet
- if (nonExistFields.nonEmpty) {
- throw new IllegalArgumentException(
- s"Field ${nonExistFields.mkString(",")} does not exist.")
- }
- // Preserve the original order of fields.
- StructType(fields.filter(f => names.contains(f.name)))
- }
-
- /**
- * Returns index of a given field
- */
- def fieldIndex(name: String): Int = {
- nameToIndex.getOrElse(name,
- throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
- }
-
- protected[sql] def toAttributes: Seq[AttributeReference] =
- map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
-
- def treeString: String = {
- val builder = new StringBuilder
- builder.append("root\n")
- val prefix = " |"
- fields.foreach(field => field.buildFormattedString(prefix, builder))
-
- builder.toString()
- }
-
- def printTreeString(): Unit = println(treeString)
-
- private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- fields.foreach(field => field.buildFormattedString(prefix, builder))
- }
-
- override private[sql] def jsonValue =
- ("type" -> typeName) ~
- ("fields" -> map(_.jsonValue))
-
- override def apply(fieldIndex: Int): StructField = fields(fieldIndex)
-
- override def length: Int = fields.length
-
- override def iterator: Iterator[StructField] = fields.iterator
-
- /**
- * The default size of a value of the StructType is the total default sizes of all field types.
- */
- override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum
-
- override def simpleString: String = {
- val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}")
- s"struct<${fieldTypes.mkString(",")}>"
- }
-
- /**
- * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field
- * B from `that`,
- *
- * 1. If A and B have the same name and data type, they are merged to a field C with the same name
- * and data type. C is nullable if and only if either A or B is nullable.
- * 2. If A doesn't exist in `that`, it's included in the result schema.
- * 3. If B doesn't exist in `this`, it's also included in the result schema.
- * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be
- * thrown.
- */
- private[sql] def merge(that: StructType): StructType =
- StructType.merge(this, that).asInstanceOf[StructType]
-
- private[spark] override def asNullable: StructType = {
- val newFields = fields.map {
- case StructField(name, dataType, nullable, metadata) =>
- StructField(name, dataType.asNullable, nullable = true, metadata)
- }
-
- StructType(newFields)
- }
-}
-
-
-object MapType {
- /**
- * Construct a [[MapType]] object with the given key type and value type;
- * `valueContainsNull` is set to true.
- */
- def apply(keyType: DataType, valueType: DataType): MapType =
- MapType(keyType: DataType, valueType: DataType, valueContainsNull = true)
-}
-
-
-/**
- * :: DeveloperApi ::
- * The data type for Maps. Keys in a map are not allowed to have `null` values.
- *
- * Please use [[DataTypes.createMapType()]] to create a specific instance.
- *
- * @param keyType The data type of map keys.
- * @param valueType The data type of map values.
- * @param valueContainsNull Indicates if map values have `null` values.
- *
- * @group dataType
- */
-case class MapType(
- keyType: DataType,
- valueType: DataType,
- valueContainsNull: Boolean) extends DataType {
-
- /** No-arg constructor for kryo. */
- def this() = this(null, null, false)
-
- private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
- builder.append(s"$prefix-- key: ${keyType.typeName}\n")
- builder.append(s"$prefix-- value: ${valueType.typeName} " +
- s"(valueContainsNull = $valueContainsNull)\n")
- DataType.buildFormattedString(keyType, s"$prefix |", builder)
- DataType.buildFormattedString(valueType, s"$prefix |", builder)
- }
-
- override private[sql] def jsonValue: JValue =
- ("type" -> typeName) ~
- ("keyType" -> keyType.jsonValue) ~
- ("valueType" -> valueType.jsonValue) ~
- ("valueContainsNull" -> valueContainsNull)
-
- /**
- * The default size of a value of the MapType is
- * 100 * (the default size of the key type + the default size of the value type).
- * (We assume that there are 100 elements).
- */
- override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
-
- override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
-
- private[spark] override def asNullable: MapType =
- MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
-}
-
-
-/**
- * ::DeveloperApi::
- * The data type for User Defined Types (UDTs).
- *
- * This interface allows a user to make their own classes more interoperable with SparkSQL;
- * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
- * a `DataFrame` which has class X in the schema.
- *
- * For SparkSQL to recognize UDTs, the UDT must be annotated with
- * [[SQLUserDefinedType]].
- *
- * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
- * The conversion via `deserialize` occurs when reading from a `DataFrame`.
- */
-@DeveloperApi
-abstract class UserDefinedType[UserType] extends DataType with Serializable {
-
- /** Underlying storage type for this UDT */
- def sqlType: DataType
-
- /** Paired Python UDT class, if one exists. */
- def pyUDT: String = null
-
- /**
- * Convert the user type to a SQL datum
- *
- * TODO: Can we make this take obj: UserType? The issue is in
- * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
- */
- def serialize(obj: Any): Any
-
- /** Convert a SQL datum to the user type */
- def deserialize(datum: Any): UserType
-
- override private[sql] def jsonValue: JValue = {
- ("type" -> "udt") ~
- ("class" -> this.getClass.getName) ~
- ("pyClass" -> pyUDT) ~
- ("sqlType" -> sqlType.jsonValue)
- }
-
- /**
- * Class object for the UserType
- */
- def userClass: java.lang.Class[UserType]
-
- /**
- * The default size of a value of the UserDefinedType is 4096 bytes.
- */
- override def defaultSize: Int = 4096
-
- /**
- * For UDT, asNullable will not change the nullability of its internal sqlType and just returns
- * itself.
- */
- private[spark] override def asNullable: UserDefinedType[UserType] = this
-}
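For orientation, a minimal sketch of a concrete [[UserDefinedType]] written against the contract removed above; the `Point` class, its UDT, and the array-of-doubles storage layout are illustrative assumptions, not part of this change:

    import org.apache.spark.sql.types._

    // Hypothetical user class; in practice it would be annotated with
    // @SQLUserDefinedType(udt = classOf[PointUDT]) so Spark SQL can find the UDT.
    case class Point(x: Double, y: Double)

    // Sketch of a UDT that stores a Point as an array of two doubles.
    class PointUDT extends UserDefinedType[Point] {

      // Underlying Catalyst storage type for serialized Points.
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      // Convert the user type to its SQL representation.
      override def serialize(obj: Any): Any = obj match {
        case Point(x, y) => Seq(x, y)
      }

      // Convert the SQL representation back to the user type.
      override def deserialize(datum: Any): Point = datum match {
        case values: Seq[_] =>
          val Seq(x: Double, y: Double) = values
          Point(x, y)
      }

      override def userClass: Class[Point] = classOf[Point]
    }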
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
index 1a0a0e6154ad2..a652c70560990 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
@@ -49,13 +49,14 @@ class SqlParserSuite extends FunSuite {
test("test long keyword") {
val parser = new SuperLongKeywordTestParser
- assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand"))
+ assert(TestCommand("NotRealCommand") ===
+ parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
}
test("test case insensitive") {
val parser = new CaseInsensitiveTestParser
- assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand"))
+ assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
+ assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
+ assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7c249215bd6b6..971e1ff5ec2b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -42,10 +42,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
def caseSensitiveAnalyze(plan: LogicalPlan): Unit =
- caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan))
+ caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan))
def caseInsensitiveAnalyze(plan: LogicalPlan): Unit =
- caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan))
+ caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan))
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
val testRelation2 = LocalRelation(
@@ -82,7 +82,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}
- assert(caseInsensitiveAnalyzer(plan).resolved)
+ assert(caseInsensitiveAnalyzer.execute(plan).resolved)
}
test("check project's resolved") {
@@ -98,11 +98,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
test("analyze project") {
assert(
- caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+ caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))
assert(
- caseSensitiveAnalyzer(
+ caseSensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -115,13 +115,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage().toLowerCase.contains("cannot resolve"))
assert(
- caseInsensitiveAnalyzer(
+ caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
assert(
- caseInsensitiveAnalyzer(
+ caseInsensitiveAnalyzer.execute(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -134,13 +134,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage == "Table Not Found: tAbLe")
assert(
- caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+ caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
assert(
- caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
+ caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
assert(
- caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+ caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
}
def errorTest(
@@ -219,7 +219,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
- val plan = caseInsensitiveAnalyzer(
+ val plan = caseInsensitiveAnalyzer.execute(
testRelation2.select(
'a / Literal(2) as 'div1,
'a / 'b as 'div2,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 67bec999dfbd1..36b03d1c65e28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -48,12 +48,12 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
private def checkType(expression: Expression, expectedType: DataType): Unit = {
val plan = Project(Seq(Alias(expression, "c")()), relation)
- assert(analyzer(plan).schema.fields(0).dataType === expectedType)
+ assert(analyzer.execute(plan).schema.fields(0).dataType === expectedType)
}
private def checkComparison(expression: Expression, expectedType: DataType): Unit = {
val plan = Project(Alias(expression, "c")() :: Nil, relation)
- val comparison = analyzer(plan).collect {
+ val comparison = analyzer.execute(plan).collect {
case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e
}.head
assert(comparison.left.dataType === expectedType)
@@ -64,7 +64,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
val plan =
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
- val (l, r) = analyzer(plan).collect {
+ val (l, r) = analyzer.execute(plan).collect {
case Union(left, right) => (left.output.head, right.output.head)
}.head
assert(l.dataType === expectedType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
index ef3114fd4dbab..b5ebe4b38e337 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala
@@ -29,7 +29,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val plan = try {
- GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)()
+ GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)()
} catch {
case e: Throwable =>
val evaluated = GenerateProjection.expressionEvaluator(expression)
@@ -56,10 +56,10 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite {
val futures = (1 to 20).map { _ =>
future {
- GeneratePredicate(EqualTo(Literal(1), Literal(1)))
- GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
- GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil)
- GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil)
+ GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
+ GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
+ GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil)
}
}
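The codegen entry points follow the same rename as above; a small sketch of the new call shapes, mirroring the expressions used in this test (the val names are illustrative):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.codegen._

    // Each generator is now invoked through an explicit generate method rather than apply().
    val predicate   = GeneratePredicate.generate(EqualTo(Literal(1), Literal(1)))
    val projection  = GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)
    // generate returns a factory here, hence the trailing () to build the projection.
    val mutableProj = GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil)()
    val ordering    = GenerateOrdering.generate(SortOrder(Add(Literal(1), Literal(1)), Ascending) :: Nil)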
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
index bcc0c404d2cfb..97af2e0fd0502 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala
@@ -25,13 +25,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
*/
class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite {
override def checkEvaluation(
- expression: Expression,
- expected: Any,
- inputRow: Row = EmptyRow): Unit = {
+ expression: Expression,
+ expected: Any,
+ inputRow: Row = EmptyRow): Unit = {
lazy val evaluated = GenerateProjection.expressionEvaluator(expression)
val plan = try {
- GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil)
+ GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)
} catch {
case e: Throwable =>
fail(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 72f06e26e05f1..6255578d7fa57 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -61,7 +61,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
def checkCondition(input: Expression, expected: Expression): Unit = {
val plan = testRelation.where(input).analyze
- val actual = Optimize(plan).expressions.head
+ val actual = Optimize.execute(plan).expressions.head
compareConditions(actual, expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index e2ae0d25db1a5..2d16d668fd522 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -44,7 +44,7 @@ class CombiningLimitsSuite extends PlanTest {
.limit(10)
.limit(5)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a)
@@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest {
.limit(7)
.limit(5)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 4396bd0dda9a9..14b28e8402610 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -47,7 +47,7 @@ class ConstantFoldingSuite extends PlanTest {
.subquery('y)
.select('a)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a.attr)
@@ -74,7 +74,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2))
)(Literal(9) / Literal(3) as Symbol("9/3"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
@@ -99,7 +99,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal(2) * 'a + Literal(4) as Symbol("c3"),
'a * (Literal(3) + Literal(4)) as Symbol("c4"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
@@ -127,7 +127,7 @@ class ConstantFoldingSuite extends PlanTest {
(Literal(1) === Literal(1) || 'b > 1) &&
(Literal(1) === Literal(2) || 'b < 10)))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
@@ -144,7 +144,7 @@ class ConstantFoldingSuite extends PlanTest {
Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"),
Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
@@ -163,7 +163,7 @@ class ConstantFoldingSuite extends PlanTest {
Rand + Literal(1) as Symbol("c1"),
Sum('a) as Symbol("c2"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
@@ -210,7 +210,7 @@ class ConstantFoldingSuite extends PlanTest {
Contains("abc", Literal.create(null, StringType)) as 'c20
)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
index cf42d43823399..6841bd9890c97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -49,7 +49,7 @@ class ConvertToLocalRelationSuite extends PlanTest {
UnresolvedAttribute("a").as("a1"),
(UnresolvedAttribute("b") + 1).as("b1"))
- val optimized = Optimize(projectOnLocal.analyze)
+ val optimized = Optimize.execute(projectOnLocal.analyze)
comparePlans(optimized, correctAnswer)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
index 2f3704be59a9d..a4a3a66b8b229 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala
@@ -30,7 +30,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite {
expected: Any,
inputRow: Row = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation)
- val optimizedPlan = DefaultOptimizer(plan)
+ val optimizedPlan = DefaultOptimizer.execute(plan)
super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 45cf695d20b01..aa9708b164efa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -50,7 +50,7 @@ class FilterPushdownSuite extends PlanTest {
.subquery('y)
.select('a)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a.attr)
@@ -65,7 +65,7 @@ class FilterPushdownSuite extends PlanTest {
.groupBy('a)('a, Count('b))
.select('a)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a)
@@ -81,7 +81,7 @@ class FilterPushdownSuite extends PlanTest {
.groupBy('a)('a as 'c, Count('b))
.select('c)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select('a)
@@ -98,7 +98,7 @@ class FilterPushdownSuite extends PlanTest {
.select('a)
.where('a === 1)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 1)
@@ -115,7 +115,7 @@ class FilterPushdownSuite extends PlanTest {
.where('e === 1)
.analyze
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a + 'b === 1)
@@ -131,7 +131,7 @@ class FilterPushdownSuite extends PlanTest {
.where('a === 1)
.where('a === 2)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where('a === 1 && 'a === 2)
@@ -152,7 +152,7 @@ class FilterPushdownSuite extends PlanTest {
.where("y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 1)
val right = testRelation.where('b === 2)
val correctAnswer =
@@ -170,7 +170,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 1)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 1)
val right = testRelation
val correctAnswer =
@@ -188,7 +188,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 1 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 1)
val right = testRelation.where('b === 2)
val correctAnswer =
@@ -206,7 +206,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 1 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 1)
val correctAnswer =
left.join(y, LeftOuter).where("y.b".attr === 2).analyze
@@ -223,7 +223,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 1 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val right = testRelation.where('b === 2).subquery('d)
val correctAnswer =
x.join(right, RightOuter).where("x.b".attr === 1).analyze
@@ -240,7 +240,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 2).subquery('d)
val correctAnswer =
left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze
@@ -257,7 +257,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val right = testRelation.where('b === 2).subquery('d)
val correctAnswer =
x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze
@@ -274,7 +274,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 2).subquery('l)
val right = testRelation.where('b === 1).subquery('r)
val correctAnswer =
@@ -292,7 +292,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val right = testRelation.where('b === 2).subquery('r)
val correctAnswer =
x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze
@@ -309,7 +309,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 2).subquery('l)
val right = testRelation.where('b === 1).subquery('r)
val correctAnswer =
@@ -327,7 +327,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.subquery('l)
val right = testRelation.where('b === 2).subquery('r)
val correctAnswer =
@@ -346,7 +346,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('b === 2).subquery('l)
val right = testRelation.where('b === 1).subquery('r)
val correctAnswer =
@@ -365,7 +365,7 @@ class FilterPushdownSuite extends PlanTest {
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a === 3).subquery('l)
val right = testRelation.where('b === 2).subquery('r)
val correctAnswer =
@@ -382,7 +382,7 @@ class FilterPushdownSuite extends PlanTest {
val originalQuery = {
x.join(y, condition = Some("x.b".attr === "y.b".attr))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized)
}
@@ -396,7 +396,7 @@ class FilterPushdownSuite extends PlanTest {
.where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a === 1).subquery('x)
val right = testRelation.where('a === 1).subquery('y)
val correctAnswer =
@@ -415,7 +415,7 @@ class FilterPushdownSuite extends PlanTest {
.where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val left = testRelation.where('a === 1).subquery('x)
val right = testRelation.subquery('y)
val correctAnswer =
@@ -436,7 +436,7 @@ class FilterPushdownSuite extends PlanTest {
("z.a".attr >= 3) && ("z.a".attr === "x.b".attr))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val lleft = testRelation.where('a >= 3).subquery('z)
val left = testRelation.where('a === 1).subquery('x)
val right = testRelation.subquery('y)
@@ -457,7 +457,7 @@ class FilterPushdownSuite extends PlanTest {
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
@@ -474,7 +474,7 @@ class FilterPushdownSuite extends PlanTest {
.generate(generator, true, false, Some("arr"))
.where(('b >= 5) && ('c > 6))
}
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val referenceResult = {
testRelationWithArrayType
.where('b >= 5)
@@ -502,7 +502,7 @@ class FilterPushdownSuite extends PlanTest {
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
}
- val optimized = Optimize(originalQuery)
+ val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, originalQuery)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index b10577c8001e2..b3df487c84dc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -41,7 +41,7 @@ class LikeSimplificationSuite extends PlanTest {
testRelation
.where(('a like "abc%") || ('a like "abc\\%"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.where(StartsWith('a, "abc") || ('a like "abc\\%"))
.analyze
@@ -54,7 +54,7 @@ class LikeSimplificationSuite extends PlanTest {
testRelation
.where('a like "%xyz")
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.where(EndsWith('a, "xyz"))
.analyze
@@ -67,7 +67,7 @@ class LikeSimplificationSuite extends PlanTest {
testRelation
.where(('a like "%mn%") || ('a like "%mn\\%"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.where(Contains('a, "mn") || ('a like "%mn\\%"))
.analyze
@@ -80,7 +80,7 @@ class LikeSimplificationSuite extends PlanTest {
testRelation
.where(('a like "") || ('a like "abc"))
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.where(('a === "") || ('a === "abc"))
.analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 966bc9ada1e6e..3eb399e68e70c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -49,7 +49,7 @@ class OptimizeInSuite extends PlanTest {
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2))))
.analyze
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
@@ -64,7 +64,7 @@ class OptimizeInSuite extends PlanTest {
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
.analyze
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
index 22992fb6f50d4..6b1e53cd42b24 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
@@ -41,7 +41,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest {
testRelation
.select(Upper(Upper('a)) as 'u)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select(Upper('a) as 'u)
@@ -55,7 +55,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest {
testRelation
.select(Upper(Lower('a)) as 'u)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.select(Upper('a) as 'u)
@@ -69,7 +69,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest {
testRelation
.select(Lower(Upper('a)) as 'l)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Lower('a) as 'l)
.analyze
@@ -82,7 +82,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest {
testRelation
.select(Lower(Lower('a)) as 'l)
- val optimized = Optimize(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
.select(Lower('a) as 'l)
.analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
index a54751dfa9a12..a3ad200800b02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -17,10 +17,9 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -41,7 +40,7 @@ class UnionPushdownSuite extends PlanTest {
test("union: filter to each side") {
val query = testUnion.where('a === 1)
- val optimized = Optimize(query.analyze)
+ val optimized = Optimize.execute(query.analyze)
val correctAnswer =
Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
@@ -52,7 +51,7 @@ class UnionPushdownSuite extends PlanTest {
test("union: project to each side") {
val query = testUnion.select('b)
- val optimized = Optimize(query.analyze)
+ val optimized = Optimize.execute(query.analyze)
val correctAnswer =
Union(testRelation.select('b), testRelation2.select('e)).analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index 4b2d45584045f..2a641c63f87bb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -34,7 +34,7 @@ class RuleExecutorSuite extends FunSuite {
val batches = Batch("once", Once, DecrementLiterals) :: Nil
}
- assert(ApplyOnce(Literal(10)) === Literal(9))
+ assert(ApplyOnce.execute(Literal(10)) === Literal(9))
}
test("to fixed point") {
@@ -42,7 +42,7 @@ class RuleExecutorSuite extends FunSuite {
val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil
}
- assert(ToFixedPoint(Literal(10)) === Literal(0))
+ assert(ToFixedPoint.execute(Literal(10)) === Literal(0))
}
test("to maxIterations") {
@@ -50,6 +50,6 @@ class RuleExecutorSuite extends FunSuite {
val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil
}
- assert(ToFixedPoint(Literal(100)) === Literal(90))
+ assert(ToFixedPoint.execute(Literal(100)) === Literal(90))
}
}
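The analyzer, optimizer, and rule-executor changes above all follow from the same API shift: executors are now run through an explicit `execute` method instead of `apply`. A minimal sketch mirroring the `DecrementLiterals` rule used by this suite:

    import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal}
    import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

    // Rule that decrements positive integer literals by one.
    object DecrementLiterals extends Rule[Expression] {
      def apply(e: Expression): Expression = e transform {
        case IntegerLiteral(i) if i > 0 => Literal(i - 1)
      }
    }

    object ToFixedPoint extends RuleExecutor[Expression] {
      val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil
    }

    ToFixedPoint.execute(Literal(10))   // previously written as ToFixedPoint(Literal(10))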
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
index 169125264a803..3e7cf7cbb5e63 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala
@@ -23,13 +23,13 @@ class DataTypeParserSuite extends FunSuite {
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
- assert(DataTypeParser(dataTypeString) === expectedDataType)
+ assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
}
}
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
- intercept[DataTypeException](DataTypeParser(dataTypeString))
+ intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
}
}
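Similarly, the data type parser is now called through parse() rather than apply(). A hedged sketch of the call shape (DataTypeParser is an sql-internal object, so this is illustrative rather than user-facing API):

```scala
import org.apache.spark.sql.types._

// Parses a Hive-style type string into a Catalyst DataType.
val t: DataType = DataTypeParser.parse("struct<name:string,age:int>")
// t is a StructType with a string "name" field and an int "age" field.
```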
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index edb229c059e6b..33f9d0b37d006 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -647,7 +647,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
*
* @group expr_ops
*/
- def cast(to: String): Column = cast(DataTypeParser(to))
+ def cast(to: String): Column = cast(DataTypeParser.parse(to))
/**
* Returns an ordering used in sorting.
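A hedged sketch of what the string overload does after this change; `df` is a hypothetical DataFrame, and both calls below should produce the same cast:

```scala
import org.apache.spark.sql.types.IntegerType

df("age").cast("int")         // type name routed through DataTypeParser.parse
df("age").cast(IntegerType)   // explicit DataType, unchanged behaviour
```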
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 03d9834d1d131..ca6ae482eb2ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -342,6 +342,43 @@ class DataFrame private[sql](
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
}
+ /**
+ * Inner equi-join with another [[DataFrame]] using the given column.
+ *
+ * Different from other join functions, the join column will only appear once in the output,
+ * i.e. similar to SQL's `JOIN USING` syntax.
+ *
+ * {{{
+ * // Joining df1 and df2 using the column "user_id"
+ * df1.join(df2, "user_id")
+ * }}}
+ *
+ * Note that if you perform a self-join using this function without aliasing the input
+ * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since
+ * there is no way to disambiguate which side of the join you would like to reference.
+ *
+ * @param right Right side of the join operation.
+ * @param usingColumn Name of the column to join on. This column must exist on both sides.
+ * @group dfops
+ */
+ def join(right: DataFrame, usingColumn: String): DataFrame = {
+ // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
+ // by creating a new instance for one of the branches.
+ val joined = sqlContext.executePlan(
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
+
+ // Project only one of the join columns.
+ val joinedCol = joined.right.resolve(usingColumn)
+ Project(
+ joined.output.filterNot(_ == joinedCol),
+ Join(
+ joined.left,
+ joined.right,
+ joinType = Inner,
+ Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn))))
+ )
+ }
+
/**
* Inner join with another [[DataFrame]], using the given join expression.
*
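A minimal usage sketch of the new single-column join, mirroring the DataFrameSuite test added later in this patch (assumes an SQLContext with implicits in scope for toDF):

```scala
val df1 = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str")

// The "int" column appears only once in the output, like SQL's JOIN USING:
// Row(1, "1", "2"), Row(2, "2", "3"), Row(3, "3", "4")
df1.join(df2, "int")
```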
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index bcd20c06c6dca..a279b0f07c38a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -132,16 +132,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
@transient
- protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_))
+ protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
@transient
protected[sql] val sqlParser = {
val fallback = new catalyst.SqlParser
- new SparkSQLParser(fallback(_))
+ new SparkSQLParser(fallback.parse(_))
}
protected[sql] def parseSql(sql: String): LogicalPlan = {
- ddlParser(sql, false).getOrElse(sqlParser(sql))
+ ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql))
}
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
@@ -1120,12 +1120,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] class QueryExecution(val logical: LogicalPlan) {
def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed)
- lazy val analyzed: LogicalPlan = analyzer(logical)
+ lazy val analyzed: LogicalPlan = analyzer.execute(logical)
lazy val withCachedData: LogicalPlan = {
assertAnalyzed()
cacheManager.useCachedData(analyzed)
}
- lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
+ lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData)
// TODO: Don't just pick the first one...
lazy val sparkPlan: SparkPlan = {
@@ -1134,7 +1134,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
+ lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[Row] = executedPlan.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index f615fb33a7c35..64449b2659b4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -61,7 +61,7 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
protected def underlyingBuffer = buffer
}
-private[sql] abstract class NativeColumnAccessor[T <: NativeType](
+private[sql] abstract class NativeColumnAccessor[T <: AtomicType](
override protected val buffer: ByteBuffer,
override protected val columnType: NativeColumnType[T])
extends BasicColumnAccessor(buffer, columnType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 00ed70430b84d..aa10af400c815 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -84,10 +84,10 @@ private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
extends BasicColumnBuilder[T, JvmType](columnStats, columnType)
with NullableColumnBuilder
-private[sql] abstract class NativeColumnBuilder[T <: NativeType](
+private[sql] abstract class NativeColumnBuilder[T <: AtomicType](
override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T])
- extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
+ extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType)
with NullableColumnBuilder
with AllCompressionSchemes
with CompressibleColumnBuilder[T]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 1b9e0df2dcb5e..20be5ca9d0046 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -101,16 +101,16 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
override def toString: String = getClass.getSimpleName.stripSuffix("$")
}
-private[sql] abstract class NativeColumnType[T <: NativeType](
+private[sql] abstract class NativeColumnType[T <: AtomicType](
val dataType: T,
typeId: Int,
defaultSize: Int)
- extends ColumnType[T, T#JvmType](typeId, defaultSize) {
+ extends ColumnType[T, T#InternalType](typeId, defaultSize) {
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
*/
- def scalaTag: TypeTag[dataType.JvmType] = dataType.tag
+ def scalaTag: TypeTag[dataType.InternalType] = dataType.tag
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
index d0b602a834dfe..cb205defbb1ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.columnar.compression
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
-private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
+private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor {
this: NativeColumnAccessor[T] =>
private var decoder: Decoder[T] = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index b9cfc5df550d1..8e2a1af6dae78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
/**
* A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of
@@ -41,7 +41,7 @@ import org.apache.spark.sql.types.NativeType
* header body
* }}}
*/
-private[sql] trait CompressibleColumnBuilder[T <: NativeType]
+private[sql] trait CompressibleColumnBuilder[T <: AtomicType]
extends ColumnBuilder with Logging {
this: NativeColumnBuilder[T] with WithCompressionSchemes =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 879d29bcfa6f6..17c2d9b111188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -22,9 +22,9 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
-private[sql] trait Encoder[T <: NativeType] {
+private[sql] trait Encoder[T <: AtomicType] {
def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {}
def compressedSize: Int
@@ -38,7 +38,7 @@ private[sql] trait Encoder[T <: NativeType] {
def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
}
-private[sql] trait Decoder[T <: NativeType] {
+private[sql] trait Decoder[T <: AtomicType] {
def next(row: MutableRow, ordinal: Int): Unit
def hasNext: Boolean
@@ -49,9 +49,9 @@ private[sql] trait CompressionScheme {
def supports(columnType: ColumnType[_, _]): Boolean
- def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T]
+ def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T]
- def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
+ def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
}
private[sql] trait WithCompressionSchemes {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 8727d71c48bb7..534ae90ddbc8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -35,16 +35,16 @@ private[sql] case object PassThrough extends CompressionScheme {
override def supports(columnType: ColumnType[_, _]): Boolean = true
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
new this.Encoder[T](columnType)
}
- override def decoder[T <: NativeType](
+ override def decoder[T <: AtomicType](
buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = {
new this.Decoder(buffer, columnType)
}
- class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
+ class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
override def uncompressedSize: Int = 0
override def compressedSize: Int = 0
@@ -56,7 +56,7 @@ private[sql] case object PassThrough extends CompressionScheme {
}
}
- class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
override def next(row: MutableRow, ordinal: Int): Unit = {
@@ -70,11 +70,11 @@ private[sql] case object PassThrough extends CompressionScheme {
private[sql] case object RunLengthEncoding extends CompressionScheme {
override val typeId = 1
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
new this.Encoder[T](columnType)
}
- override def decoder[T <: NativeType](
+ override def decoder[T <: AtomicType](
buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = {
new this.Decoder(buffer, columnType)
}
@@ -84,7 +84,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
case _ => false
}
- class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
+ class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
private var _uncompressedSize = 0
private var _compressedSize = 0
@@ -152,12 +152,12 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
}
}
- class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
private var run = 0
private var valueCount = 0
- private var currentValue: T#JvmType = _
+ private var currentValue: T#InternalType = _
override def next(row: MutableRow, ordinal: Int): Unit = {
if (valueCount == run) {
@@ -181,12 +181,12 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
// 32K unique values allowed
val MAX_DICT_SIZE = Short.MaxValue
- override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
: Decoder[T] = {
new this.Decoder(buffer, columnType)
}
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
new this.Encoder[T](columnType)
}
@@ -195,7 +195,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
case _ => false
}
- class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
+ class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
// Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
// overflows.
private var _uncompressedSize = 0
@@ -208,7 +208,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
private var count = 0
// The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself.
- private var values = new mutable.ArrayBuffer[T#JvmType](1024)
+ private var values = new mutable.ArrayBuffer[T#InternalType](1024)
// The dictionary that maps a value to the encoded short integer.
private val dictionary = mutable.HashMap.empty[Any, Short]
@@ -268,14 +268,14 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2
}
- class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
val mirror = runtimeMirror(Utils.getSparkClassLoader)
- ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
+ ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe))
}
Array.fill(buffer.getInt()) {
@@ -296,12 +296,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
val BITS_PER_LONG = 64
- override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
: compression.Decoder[T] = {
new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
(new this.Encoder).asInstanceOf[compression.Encoder[T]]
}
@@ -384,12 +384,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
private[sql] case object IntDelta extends CompressionScheme {
override def typeId: Int = 4
- override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
: compression.Decoder[T] = {
new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
(new Encoder).asInstanceOf[compression.Encoder[T]]
}
@@ -464,12 +464,12 @@ private[sql] case object IntDelta extends CompressionScheme {
private[sql] case object LongDelta extends CompressionScheme {
override def typeId: Int = 5
- override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
: compression.Decoder[T] = {
new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
+ override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
(new Encoder).asInstanceOf[compression.Encoder[T]]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index e159ffe66cb24..59c89800da00f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -144,7 +144,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled) {
- GenerateProjection(expressions, inputSchema)
+ GenerateProjection.generate(expressions, inputSchema)
} else {
new InterpretedProjection(expressions, inputSchema)
}
@@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled) {
- GenerateMutableProjection(expressions, inputSchema)
+ GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
}
@@ -166,15 +166,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def newPredicate(
expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
if (codegenEnabled) {
- GeneratePredicate(expression, inputSchema)
+ GeneratePredicate.generate(expression, inputSchema)
} else {
- InterpretedPredicate(expression, inputSchema)
+ InterpretedPredicate.create(expression, inputSchema)
}
}
protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = {
if (codegenEnabled) {
- GenerateOrdering(order, inputSchema)
+ GenerateOrdering.generate(order, inputSchema)
} else {
new RowOrdering(order, inputSchema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
new file mode 100644
index 0000000000000..fe7607c6ac340
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
+import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{IntegerType, DataType}
+
+
+/**
+ * Expression that returns the current partition id of the Spark task.
+ */
+case object SparkPartitionID extends Expression with trees.LeafNode[Expression] {
+ self: Product =>
+
+ override type EvaluatedType = Int
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = IntegerType
+
+ override def eval(input: Row): Int = TaskContext.get().partitionId()
+}
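The user-facing entry point for this expression is the sparkPartitionId() function added to functions.scala further down. A hedged usage sketch (sc is a hypothetical SparkContext and sqlContext.implicits._ is assumed to be in scope):

```scala
import org.apache.spark.sql.functions._

val df = sc.parallelize(1 to 8, 4).map(i => (i, i)).toDF("a", "b")

// Tags each row with the id of the partition that produced it. The value depends on
// partitioning and scheduling, so it is not deterministic across runs.
df.select($"a", sparkPartitionId()).show()
```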
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala
new file mode 100644
index 0000000000000..568b7ac2c5987
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+/**
+ * Package containing expressions that are specific to the Spark runtime.
+ */
+package object expressions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 83b1a83765153..56200f6b8c8a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -59,7 +59,7 @@ case class BroadcastNestedLoopJoin(
}
@transient private lazy val boundCondition =
- InterpretedPredicate(
+ InterpretedPredicate.create(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index 1fa7e7bd0406c..e06f63f94b78b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -45,7 +45,7 @@ case class LeftSemiJoinBNL(
override def right: SparkPlan = broadcast
@transient private lazy val boundCondition =
- InterpretedPredicate(
+ InterpretedPredicate.create(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ff91e1d74bc2c..9738fd4f93bad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -276,6 +276,13 @@ object functions {
// Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
+ /**
+ * Computes the absolute value.
+ *
+ * @group normal_funcs
+ */
+ def abs(e: Column): Column = Abs(e.expr)
+
/**
* Returns the first column that is not null.
* {{{
@@ -287,6 +294,13 @@ object functions {
@scala.annotation.varargs
def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
+ /**
+ * Converts a string expression to lower case.
+ *
+ * @group normal_funcs
+ */
+ def lower(e: Column): Column = Lower(e.expr)
+
/**
* Unary minus, i.e. negate the expression.
* {{{
@@ -317,18 +331,13 @@ object functions {
def not(e: Column): Column = !e
/**
- * Converts a string expression to upper case.
+ * Partition ID of the Spark task.
*
- * @group normal_funcs
- */
- def upper(e: Column): Column = Upper(e.expr)
-
- /**
- * Converts a string exprsesion to lower case.
+ * Note that this is non-deterministic because it depends on data partitioning and task scheduling.
*
* @group normal_funcs
*/
- def lower(e: Column): Column = Lower(e.expr)
+ def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
/**
* Computes the square root of the specified float value.
@@ -338,11 +347,11 @@ object functions {
def sqrt(e: Column): Column = Sqrt(e.expr)
/**
- * Computes the absolutle value.
+ * Converts a string expression to upper case.
*
* @group normal_funcs
*/
- def abs(e: Column): Column = Abs(e.expr)
+ def upper(e: Column): Column = Upper(e.expr)
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index b9022fcd9e3ad..f326510042122 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources._
+import org.apache.spark.util.Utils
private[sql] object JDBCRDD extends Logging {
/**
@@ -60,6 +61,7 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.NCLOB => StringType
case java.sql.Types.NULL => null
case java.sql.Types.NUMERIC => DecimalType.Unlimited
+ case java.sql.Types.NVARCHAR => StringType
case java.sql.Types.OTHER => null
case java.sql.Types.REAL => DoubleType
case java.sql.Types.REF => StringType
@@ -151,7 +153,7 @@ private[sql] object JDBCRDD extends Logging {
def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
() => {
try {
- if (driver != null) Class.forName(driver)
+ if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
} catch {
case e: ClassNotFoundException => {
logWarning(s"Couldn't find class $driver", e);
@@ -349,8 +351,8 @@ private[sql] class JDBCRDD(
val pos = i + 1
conversions(i) match {
case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
- // TODO(davies): convert Date into Int
- case DateConversion => mutableRow.update(i, rs.getDate(pos))
+ case DateConversion =>
+ mutableRow.update(i, DateUtils.fromJavaDate(rs.getDate(pos)))
case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos))
case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
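A hedged end-to-end sketch of the JDBC read path touched above: the driver class is now loaded through the context/Spark class loader (so drivers shipped via --jars are found), and DATE columns are converted with DateUtils.fromJavaDate. The URL and table name below are hypothetical, patterned on the H2 setup in JDBCSuite:

```scala
val df = sqlContext.jdbc("jdbc:h2:mem:testdb0;user=testUser;password=testPass", "TEST.TIMETYPES")
df.printSchema()   // DATE columns surface as DateType
```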
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 29de7401dda71..6e94e7056eb0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -183,7 +183,7 @@ private[sql] object JsonRDD extends Logging {
private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = {
// For Integer values, use LongType by default.
val useLongType: PartialFunction[Any, DataType] = {
- case value: IntegerType.JvmType => LongType
+ case value: IntegerType.InternalType => LongType
}
useLongType orElse ScalaReflection.typeOfObject orElse {
@@ -411,11 +411,11 @@ private[sql] object JsonRDD extends Logging {
desiredType match {
case StringType => UTF8String(toString(value))
case _ if value == null || value == "" => null // guard the non string type
- case IntegerType => value.asInstanceOf[IntegerType.JvmType]
+ case IntegerType => value.asInstanceOf[IntegerType.InternalType]
case LongType => toLong(value)
case DoubleType => toDouble(value)
case DecimalType() => toDecimal(value)
- case BooleanType => value.asInstanceOf[BooleanType.JvmType]
+ case BooleanType => value.asInstanceOf[BooleanType.InternalType]
case NullType => null
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index bc108e37dfb0f..36cb5e03bbca7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -90,7 +90,7 @@ private[sql] object CatalystConverter {
createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
}
// For native JVM types we use a converter with native arrays
- case ArrayType(elementType: NativeType, false) => {
+ case ArrayType(elementType: AtomicType, false) => {
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
}
// This is for other types of arrays, including those with nested fields
@@ -118,19 +118,19 @@ private[sql] object CatalystConverter {
case ShortType => {
new CatalystPrimitiveConverter(parent, fieldIndex) {
override def addInt(value: Int): Unit =
- parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType])
+ parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType])
}
}
case ByteType => {
new CatalystPrimitiveConverter(parent, fieldIndex) {
override def addInt(value: Int): Unit =
- parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType])
+ parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType])
}
}
case DateType => {
new CatalystPrimitiveConverter(parent, fieldIndex) {
override def addInt(value: Int): Unit =
- parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType])
+ parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType])
}
}
case d: DecimalType => {
@@ -146,7 +146,8 @@ private[sql] object CatalystConverter {
}
}
// All other primitive types use the default converter
- case ctype: PrimitiveType => { // note: need the type tag here!
+ case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => {
+ // note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
}
case _ => throw new RuntimeException(
@@ -324,9 +325,9 @@ private[parquet] class CatalystGroupConverter(
override def start(): Unit = {
current = ArrayBuffer.fill(size)(null)
- converters.foreach {
- converter => if (!converter.isPrimitive) {
- converter.asInstanceOf[CatalystConverter].clearBuffer
+ converters.foreach { converter =>
+ if (!converter.isPrimitive) {
+ converter.asInstanceOf[CatalystConverter].clearBuffer()
}
}
}
@@ -612,7 +613,7 @@ private[parquet] class CatalystArrayConverter(
override def start(): Unit = {
if (!converter.isPrimitive) {
- converter.asInstanceOf[CatalystConverter].clearBuffer
+ converter.asInstanceOf[CatalystConverter].clearBuffer()
}
}
@@ -636,13 +637,13 @@ private[parquet] class CatalystArrayConverter(
* @param capacity The (initial) capacity of the buffer
*/
private[parquet] class CatalystNativeArrayConverter(
- val elementType: NativeType,
+ val elementType: AtomicType,
val index: Int,
protected[parquet] val parent: CatalystConverter,
protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE)
extends CatalystConverter {
- type NativeType = elementType.JvmType
+ type NativeType = elementType.InternalType
private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 1c868da23e060..a938b77578686 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -268,7 +268,7 @@ private[sql] case class InsertIntoParquetTable(
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
val writeSupport =
- if (child.output.map(_.dataType).forall(_.isPrimitive)) {
+ if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
log.debug("Initializing MutableRowWriteSupport")
classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport]
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index e05a4c20b0d41..c45c431438efc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -189,7 +189,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case t @ StructType(_) => writeStruct(
t,
value.asInstanceOf[CatalystConverter.StructScalaType[_]])
- case _ => writePrimitive(schema.asInstanceOf[NativeType], value)
+ case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index 60e1bec4db8e5..1dc819b5d7b9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -48,8 +48,10 @@ private[parquet] case class ParquetTypeInfo(
length: Option[Int] = None)
private[parquet] object ParquetTypesConverter extends Logging {
- def isPrimitiveType(ctype: DataType): Boolean =
- classOf[PrimitiveType] isAssignableFrom ctype.getClass
+ def isPrimitiveType(ctype: DataType): Boolean = ctype match {
+ case _: NumericType | BooleanType | StringType | BinaryType => true
+ case _: DataType => false
+ }
def toPrimitiveDataType(
parquetType: ParquetPrimitiveType,
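A hedged sketch of the new check (ParquetTypesConverter is parquet-internal; shown only to illustrate the pattern match above):

```scala
import org.apache.spark.sql.types._

ParquetTypesConverter.isPrimitiveType(StringType)              // true
ParquetTypesConverter.isPrimitiveType(ArrayType(IntegerType))  // false
ParquetTypesConverter.isPrimitiveType(StructType(Nil))         // false
```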
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index af7b3c81ae7b2..85e60733bc57a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -611,7 +611,7 @@ private[sql] case class ParquetRelation2(
val rawPredicate =
partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true))
- val boundPredicate = InterpretedPredicate(rawPredicate transform {
+ val boundPredicate = InterpretedPredicate.create(rawPredicate transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
BoundReference(index, partitionColumns(index).dataType, nullable = true)
@@ -634,12 +634,13 @@ private[sql] case class ParquetRelation2(
// before calling execute().
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
- val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) {
- log.debug("Initializing MutableRowWriteSupport")
- classOf[MutableRowWriteSupport]
- } else {
- classOf[RowWriteSupport]
- }
+ val writeSupport =
+ if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
+ log.debug("Initializing MutableRowWriteSupport")
+ classOf[MutableRowWriteSupport]
+ } else {
+ classOf[RowWriteSupport]
+ }
ParquetOutputFormat.setWriteSupportClass(job, writeSupport)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index 78d494184e759..e7a0685e013d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -38,9 +38,9 @@ private[sql] class DDLParser(
parseQuery: String => LogicalPlan)
extends AbstractSparkSQLParser with DataTypeParser with Logging {
- def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
+ def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
try {
- Some(apply(input))
+ Some(parse(input))
} catch {
case ddlException: DDLException => throw ddlException
case _ if !exceptionOnError => None
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index fc3ed4a708d46..e02c84872c628 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -162,7 +162,7 @@ public void testCreateDataFrameFromJavaBeans() {
Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
- Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
+ Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer)));
Seq d = first.getAs(3);
Assert.assertEquals(bean.getD().size(), d.length());
for (int i = 0; i < d.length(); i++) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 01e3b8671071e..0772e5e187425 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -300,19 +300,26 @@ class CachedTableSuite extends QueryTest {
}
test("Clear accumulators when uncacheTable to prevent memory leaking") {
- val accsSize = Accumulators.originals.size
-
sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
- cacheTable("t1")
- cacheTable("t2")
+
+ Accumulators.synchronized {
+ val accsSize = Accumulators.originals.size
+ cacheTable("t1")
+ cacheTable("t2")
+ assert((accsSize + 2) == Accumulators.originals.size)
+ }
+
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
- uncacheTable("t1")
- uncacheTable("t2")
- assert(accsSize >= Accumulators.originals.size)
+ Accumulators.synchronized {
+ val accsSize = Accumulators.originals.size
+ uncacheTable("t1")
+ uncacheTable("t2")
+ assert((accsSize - 2) == Accumulators.originals.size)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index bc8fae100db6a..904073b8cb2aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest {
)
}
+ test("sparkPartitionId") {
+ val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
+ checkAnswer(
+ df.select(sparkPartitionId()),
+ Row(0)
+ )
+ }
+
test("lift alias out of cast") {
compareExpressions(
col("1234").as("name").cast("int").expr,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b9b6a400ae195..5ec06d448e50f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -109,15 +109,6 @@ class DataFrameSuite extends QueryTest {
assert(testData.head(2).head.schema === testData.schema)
}
- test("self join") {
- val df1 = testData.select(testData("key")).as('df1)
- val df2 = testData.select(testData("key")).as('df2)
-
- checkAnswer(
- df1.join(df2, $"df1.key" === $"df2.key"),
- sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
- }
-
test("simple explode") {
val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
@@ -127,8 +118,35 @@ class DataFrameSuite extends QueryTest {
)
}
- test("self join with aliases") {
- val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str")
+ test("join - join using") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+ val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str")
+
+ checkAnswer(
+ df.join(df2, "int"),
+ Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil)
+ }
+
+ test("join - join using self join") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
+
+ // self join
+ checkAnswer(
+ df.join(df, "int"),
+ Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil)
+ }
+
+ test("join - self join") {
+ val df1 = testData.select(testData("key")).as('df1)
+ val df2 = testData.select(testData("key")).as('df2)
+
+ checkAnswer(
+ df1.join(df2, $"df1.key" === $"df2.key"),
+ sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
+ }
+
+ test("join - using aliases after self join") {
+ val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
checkAnswer(
df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index fec487f1d2c82..7cefcf44061ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -34,7 +34,7 @@ class ColumnStatsSuite extends FunSuite {
testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
- def testColumnStats[T <: NativeType, U <: ColumnStats](
+ def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
initialStatistics: Row): Unit = {
@@ -55,8 +55,8 @@ class ColumnStatsSuite extends FunSuite {
val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
rows.foreach(columnStats.gatherStats(_, 0))
- val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType])
- val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+ val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType])
+ val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index b48bed1871c50..1e105e259dce7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -196,12 +196,12 @@ class ColumnTypeSuite extends FunSuite with Logging {
}
}
- def testNativeColumnType[T <: NativeType](
+ def testNativeColumnType[T <: AtomicType](
columnType: NativeColumnType[T],
- putter: (ByteBuffer, T#JvmType) => Unit,
- getter: (ByteBuffer) => T#JvmType): Unit = {
+ putter: (ByteBuffer, T#InternalType) => Unit,
+ getter: (ByteBuffer) => T#InternalType): Unit = {
- testColumnType[T, T#JvmType](columnType, putter, getter)
+ testColumnType[T, T#InternalType](columnType, putter, getter)
}
def testColumnType[T <: DataType, JvmType](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
index f76314b9dab5e..75d993e563e06 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -24,7 +24,7 @@ import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType}
+import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType}
object ColumnarTestUtils {
def makeNullRow(length: Int): GenericMutableRow = {
@@ -91,9 +91,9 @@ object ColumnarTestUtils {
row
}
- def makeUniqueValuesAndSingleValueRows[T <: NativeType](
+ def makeUniqueValuesAndSingleValueRows[T <: AtomicType](
columnType: NativeColumnType[T],
- count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = {
+ count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = {
val values = makeUniqueRandomValues(columnType, count)
val rows = values.map { value =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index c82d9799359c7..64b70552eb047 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -24,14 +24,14 @@ import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
class DictionaryEncodingSuite extends FunSuite {
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
testDictionaryEncoding(new StringColumnStats, STRING)
- def testDictionaryEncoding[T <: NativeType](
+ def testDictionaryEncoding[T <: AtomicType](
columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 88011631ee4e3..bfd99f143bedc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -33,7 +33,7 @@ class IntegralDeltaSuite extends FunSuite {
columnType: NativeColumnType[I],
scheme: CompressionScheme) {
- def skeleton(input: Seq[I#JvmType]) {
+ def skeleton(input: Seq[I#InternalType]) {
// -------------
// Tests encoder
// -------------
@@ -120,13 +120,13 @@ class IntegralDeltaSuite extends FunSuite {
case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long)
}
- skeleton(input.map(_.asInstanceOf[I#JvmType]))
+ skeleton(input.map(_.asInstanceOf[I#InternalType]))
}
test(s"$scheme: long random series") {
// Have to workaround with `Any` since no `ClassTag[I#JvmType]` available here.
val input = Array.fill[Any](10000)(makeRandomValue(columnType))
- skeleton(input.map(_.asInstanceOf[I#JvmType]))
+ skeleton(input.map(_.asInstanceOf[I#InternalType]))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 08df1db375097..fde7a4595be0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new NoopColumnStats, BOOLEAN)
@@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite {
testRunLengthEncoding(new LongColumnStats, LONG)
testRunLengthEncoding(new StringColumnStats, STRING)
- def testRunLengthEncoding[T <: NativeType](
+ def testRunLengthEncoding[T <: AtomicType](
columnStats: ColumnStats,
columnType: NativeColumnType[T]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
index fc8ff3b41d0e6..5268dfe0aa03e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.columnar.compression
import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.AtomicType
-class TestCompressibleColumnBuilder[T <: NativeType](
+class TestCompressibleColumnBuilder[T <: AtomicType](
override val columnStats: ColumnStats,
override val columnType: NativeColumnType[T],
override val schemes: Seq[CompressionScheme])
@@ -32,7 +32,7 @@ class TestCompressibleColumnBuilder[T <: NativeType](
}
object TestCompressibleColumnBuilder {
- def apply[T <: NativeType](
+ def apply[T <: AtomicType](
columnStats: ColumnStats,
columnType: NativeColumnType[T],
scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 3596b183d4328..db096af4535a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -249,6 +249,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543)
}
+ test("test DATE types") {
+ val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect()
+ val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect()
+ assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
+ assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
+ }
+
test("H2 floating-point types") {
val rows = sql("SELECT * FROM flttypes").collect()
assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==.
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 04440076a26a3..21dce8d8a565a 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -59,6 +59,11 @@
      <groupId>${hive.group}</groupId>
      <artifactId>hive-exec</artifactId>
    </dependency>
+    <dependency>
+      <groupId>org.apache.httpcomponents</groupId>
+      <artifactId>httpclient</artifactId>
+      <version>${commons.httpclient.version}</version>
+    </dependency>
    <dependency>
      <groupId>org.codehaus.jackson</groupId>
      <artifactId>jackson-mapper-asl</artifactId>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index c4a73b3004076..dd06b2620c5ee 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -93,7 +93,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false)
+ val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false)
DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index f1c0bd92aa23d..4d222cf88e5e8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -871,7 +871,7 @@ private[hive] case class MetastoreRelation
private[hive] object HiveMetastoreTypes {
- def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType)
+ def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType)
def toMetastoreType(dt: DataType): String = dt match {
case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>"
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 85061f22772dd..0ea6d57b816c6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -144,7 +144,7 @@ private[hive] object HiveQl {
protected val hqlParser = {
val fallback = new ExtendedHiveQlParser
- new SparkSQLParser(fallback(_))
+ new SparkSQLParser(fallback.parse(_))
}
/**
@@ -240,7 +240,7 @@ private[hive] object HiveQl {
/** Returns a LogicalPlan for a given HiveQL string. */
- def parseSql(sql: String): LogicalPlan = hqlParser(sql)
+ def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql)
val errorRegEx = "line (\\d+):(\\d+) (.*)".r
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index a6f4fbe8aba06..be9249a8b1f44 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -119,9 +119,9 @@ private[hive] trait HiveStrategies {
val inputData = new GenericMutableRow(relation.partitionKeys.size)
val pruningCondition =
if (codegenEnabled) {
- GeneratePredicate(castedPredicate)
+ GeneratePredicate.generate(castedPredicate)
} else {
- InterpretedPredicate(castedPredicate)
+ InterpretedPredicate.create(castedPredicate)
}
val partitions = relation.hiveQlPartitions.filter { part =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index cab0fdd35723a..3eddda3b28c66 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -145,20 +145,29 @@ case class ScriptTransformation(
val dataOutputStream = new DataOutputStream(outputStream)
val outputProjection = new InterpretedProjection(input, child.output)
- iter
- .map(outputProjection)
- .foreach { row =>
- if (inputSerde == null) {
- val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
- ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
-
- outputStream.write(data)
- } else {
- val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi)
- prepareWritable(writable).write(dataOutputStream)
+ // Put the writing (output to the pipe) into a separate thread
+ // and keep the collector in the main thread;
+ // otherwise it causes a deadlock when the data size exceeds
+ // the pipe / buffer capacity.
+ new Thread(new Runnable() {
+ override def run(): Unit = {
+ iter
+ .map(outputProjection)
+ .foreach { row =>
+ if (inputSerde == null) {
+ val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+ ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+
+ outputStream.write(data)
+ } else {
+ val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi)
+ prepareWritable(writable).write(dataOutputStream)
+ }
}
+ outputStream.close()
}
- outputStream.close()
+ }).start()
+
iterator
}
}
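
The comment added above explains why the writer now runs on its own thread: the external script reads its input through a pipe with a bounded buffer, so writing the input and collecting the output from one thread deadlocks once the data no longer fits in that buffer. A minimal, self-contained sketch of the same producer-thread pattern using java.io piped streams; the names and sizes here are illustrative, not Spark's:

import java.io.{BufferedReader, InputStreamReader, PipedInputStream, PipedOutputStream}

object PipedWriterSketch extends App {
  val out = new PipedOutputStream()
  val in = new PipedInputStream(out)      // default pipe buffer is only 1024 bytes

  // The writer runs on its own thread so the reader below can drain the pipe;
  // doing both on one thread would block once the buffer fills up.
  new Thread(new Runnable {
    override def run(): Unit = {
      try {
        (1 to 100000).foreach(i => out.write(s"row $i\n".getBytes("UTF-8")))
      } finally {
        out.close()                       // signals end-of-stream to the reader
      }
    }
  }).start()

  val reader = new BufferedReader(new InputStreamReader(in, "UTF-8"))
  val count = Iterator.continually(reader.readLine()).takeWhile(_ != null).size
  println(s"read $count lines")
}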
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 6570fa1043900..9f17bca083d13 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -185,7 +185,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}")
referencedTestTables.foreach(loadTestTable)
// Proceed with analysis.
- analyzer(logical)
+ analyzer.execute(logical)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 47b4cb9ca61ff..4f8d0ac0e7656 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -561,4 +561,12 @@ class SQLQuerySuite extends QueryTest {
sql("select d from dn union all select d * 2 from dn")
.queryExecution.analyzed
}
+
+ test("test script transform") {
+ val data = (1 to 100000).map { i => (i, i, i) }
+ data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
+ assert(100000 ===
+ sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
+ .queryExecution.toRdd.count())
+ }
}
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index d331c210e8939..dbc5e029e2047 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -19,11 +19,15 @@ package org.apache.spark.sql.hive
import java.rmi.server.UID
import java.util.{Properties, ArrayList => JArrayList}
+import java.io.{OutputStream, InputStream}
import scala.collection.JavaConversions._
import scala.language.implicitConversions
+import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
+import com.esotericsoftware.kryo.io.Input
+import com.esotericsoftware.kryo.io.Output
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.common.StatsSetupConst
@@ -46,6 +50,7 @@ import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.Logging
import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String}
+import org.apache.spark.util.Utils._
/**
* This class provides the UDF creation and also the UDF instance serialization and
@@ -61,39 +66,34 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String)
// for Serialization
def this() = this(null)
- import org.apache.spark.util.Utils._
-
@transient
- private val methodDeSerialize = {
- val method = classOf[Utilities].getDeclaredMethod(
- "deserializeObjectByKryo",
- classOf[Kryo],
- classOf[java.io.InputStream],
- classOf[Class[_]])
- method.setAccessible(true)
-
- method
+ def deserializeObjectByKryo[T: ClassTag](
+ kryo: Kryo,
+ in: InputStream,
+ clazz: Class[_]): T = {
+ val inp = new Input(in)
+ val t: T = kryo.readObject(inp, clazz).asInstanceOf[T]
+ inp.close()
+ t
}
@transient
- private val methodSerialize = {
- val method = classOf[Utilities].getDeclaredMethod(
- "serializeObjectByKryo",
- classOf[Kryo],
- classOf[Object],
- classOf[java.io.OutputStream])
- method.setAccessible(true)
-
- method
+ def serializeObjectByKryo(
+ kryo: Kryo,
+ plan: Object,
+ out: OutputStream): Unit = {
+ val output: Output = new Output(out)
+ kryo.writeObject(output, plan)
+ output.close()
}
def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
- methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz)
+ deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz)
.asInstanceOf[UDFType]
}
def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
- methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out)
+ serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out)
}
private var instance: AnyRef = null
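
The rewrite above stops reflecting into Hive's private Utilities helpers and serializes the UDF plan directly with Kryo's Input/Output. A minimal sketch of that round trip with a standalone Kryo instance; SimplePlan is a hypothetical payload, and default Kryo registration/instantiation behavior varies across Kryo versions:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Hypothetical payload; Kryo's default instantiator wants a no-arg constructor.
class SimplePlan(var name: String) {
  def this() = this(null)
}

object KryoRoundTrip extends App {
  val kryo = new Kryo()

  // Serialize: writeObject through an Output, as serializeObjectByKryo does.
  val bytes = new ByteArrayOutputStream()
  val output = new Output(bytes)
  kryo.writeObject(output, new SimplePlan("my-udf"))
  output.close()

  // Deserialize: readObject through an Input, as deserializeObjectByKryo does.
  val input = new Input(new ByteArrayInputStream(bytes.toByteArray))
  val restored = kryo.readObject(input, classOf[SimplePlan])
  input.close()

  println(restored.name)                  // my-udf
}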
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index dcdc27d29c270..297bf04c0c25e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage._
import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager}
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{ThreadUtils, Clock, SystemClock}
/** Trait that represents the metadata related to storage of blocks */
private[streaming] trait ReceivedBlockStoreResult {
@@ -150,7 +150,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
// For processing futures used in parallel block storing into block manager and write ahead log
// # threads = 2, so that both writing to BM and WAL can proceed in parallel
implicit private val executionContext = ExecutionContext.fromExecutorService(
- Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName))
+ ThreadUtils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName))
/**
* This implementation stores the block into the block manager as well as a write ahead log.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
index 6bdfe45dc7f83..38a93cc3c9a1f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
@@ -25,7 +25,7 @@ import scala.language.postfixOps
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
-import org.apache.spark.util.{Clock, SystemClock, Utils}
+import org.apache.spark.util.{ThreadUtils, Clock, SystemClock}
import WriteAheadLogManager._
/**
@@ -60,7 +60,7 @@ private[streaming] class WriteAheadLogManager(
if (callerName.nonEmpty) s" for $callerName" else ""
private val threadpoolName = s"WriteAheadLogManager $callerNameTag"
implicit private val executionContext = ExecutionContext.fromExecutorService(
- Utils.newDaemonFixedThreadPool(1, threadpoolName))
+ ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName))
override protected val logName = s"WriteAheadLogManager $callerNameTag"
private var currentLogPath: Option[String] = None
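
Both streaming hunks swap Utils for ThreadUtils when building the daemon executors behind their ExecutionContexts. A minimal sketch of the same shape with plain java.util.concurrent; the hand-rolled ThreadFactory below is only a stand-in for ThreadUtils.newDaemonSingleThreadExecutor, not its implementation:

import java.util.concurrent.{Executors, ThreadFactory}

import scala.concurrent.{ExecutionContext, Future}

object DaemonExecutorSketch extends App {
  // A single daemon thread, so pending work never keeps the JVM alive on shutdown.
  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "write-ahead-log-sketch")
      t.setDaemon(true)
      t
    }
  }

  implicit val executionContext: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor(factory))

  Future { Thread.currentThread().getName }.foreach(name => println(s"ran on $name"))
  Thread.sleep(100)                       // let the daemon thread print before exit
}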
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index e7aee6eadbfc7..b84129fd70dd4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -155,7 +155,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
assert(recordedData.toSet === generatedData.toSet)
}
- test("block generator throttling") {
+ ignore("block generator throttling") {
val blockGeneratorListener = new FakeBlockGeneratorListener
val blockIntervalMs = 100
val maxRate = 1001