diff --git a/assembly/pom.xml b/assembly/pom.xml index 22bbbc57d..b5e752c6c 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -79,6 +79,11 @@ spark-graphx_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + net.sf.py4j py4j diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 278969655..0624117f4 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -33,23 +33,43 @@ fi # Build up classpath CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf" +# Support for interacting with Hive. Since hive pulls in a lot of dependencies that might break +# existing Spark applications, it is not included in the standard spark assembly. Instead, we only +# include it in the classpath if the user has explicitly requested it by running "sbt hive/assembly" +# Hopefully we will find a way to avoid uber-jars entirely and deploy only the needed packages in +# the future. +if [ -f "$FWDIR"/sql/hive/target/scala-$SCALA_VERSION/spark-hive-assembly-*.jar ]; then + echo "Hive assembly found, including hive support. If this isn't desired run sbt hive/clean." + + # Datanucleus jars do not work if only included in the uberjar as plugin.xml metadata is lost. + DATANUCLEUSJARS=$(JARS=("$FWDIR/lib_managed/jars"/datanucleus-*.jar); IFS=:; echo "${JARS[*]}") + CLASSPATH=$CLASSPATH:$DATANUCLEUSJARS + + ASSEMBLY_DIR="$FWDIR/sql/hive/target/scala-$SCALA_VERSION/" +else + ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION/" +fi + # First check if we have a dependencies jar. 
If so, include binary classes with the deps jar -if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then +if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar` + DEPS_ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark*-assembly*hadoop*-deps.jar` CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR" else # Else use spark-assembly jar from either RELEASE or assembly directory if [ -f "$FWDIR/RELEASE" ]; then - ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar` + ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark*-assembly*.jar` else - ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar` + ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark*-assembly*hadoop*.jar` fi CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" fi @@ -62,6 +82,9 @@ if [[ $SPARK_TESTING == 1 ]]; then CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes" CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" + 
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/dev/download-hive-tests.sh b/dev/download-hive-tests.sh new file mode 100755 index 000000000..6c412a849 --- /dev/null +++ b/dev/download-hive-tests.sh @@ -0,0 +1,4 @@ +#!/bin/sh + +wget -O hiveTests.tgz http://cs.berkeley.edu/~marmbrus/tmp/hiveTests.tgz +tar zxf hiveTests.tgz \ No newline at end of file diff --git a/dev/run-tests b/dev/run-tests index cf0b940c0..b62a25f42 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,6 +21,9 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR +# Download Hive Compatibility Files +dev/download-hive-tests.sh + # Remove work directory rm -rf ./work diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 49fd78ca9..5d4dbb7a9 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -66,6 +66,7 @@
  • Spark in Python
  • Spark Streaming
  • +
  • Spark SQL
  • MLlib (Machine Learning)
  • Bagel (Pregel on Spark)
  • GraphX (Graph Processing)
  • @@ -79,6 +80,14 @@
  • Spark Core for Python
  • Spark Streaming
  • +
  • MLlib (Machine Learning)
  • Bagel (Pregel on Spark)
  • GraphX (Graph Processing)
  • diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 44d64057f..2245bcbc7 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -22,6 +22,7 @@ # Build Scaladoc for Java/Scala core_projects = ["core", "examples", "repl", "bagel", "graphx", "streaming", "mllib"] external_projects = ["flume", "kafka", "mqtt", "twitter", "zeromq"] + sql_projects = ["catalyst", "core", "hive"] projects = core_projects + external_projects.map { |project_name| "external/" + project_name } @@ -49,6 +50,18 @@ cp_r(source + "/.", dest) end + sql_projects.each do |project_name| + source = "../sql/" + project_name + "/target/scala-2.10/api/" + dest = "api/sql/" + project_name + + puts "echo making directory " + dest + mkdir_p dest + + # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. + puts "cp -r " + source + "/. " + dest + cp_r(source + "/.", dest) + end + # Build Epydoc for Python puts "Moving to python directory and building epydoc." 
cd("../python") diff --git a/docs/index.md b/docs/index.md index 23311101e..7a13fa9a9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -78,6 +78,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui * [Java Programming Guide](java-programming-guide.html): using Spark from Java * [Python Programming Guide](python-programming-guide.html): using Spark from Python * [Spark Streaming](streaming-programming-guide.html): Spark's API for processing data streams +* [Spark SQL](sql-programming-guide.html): Support for running relational queries on Spark * [MLlib (Machine Learning)](mllib-guide.html): Spark's built-in machine learning library * [Bagel (Pregel on Spark)](bagel-programming-guide.html): simple graph processing model * [GraphX (Graphs on Spark)](graphx-programming-guide.html): Spark's new API for graphs diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md new file mode 100644 index 000000000..b6f21a5dc --- /dev/null +++ b/docs/sql-programming-guide.md @@ -0,0 +1,143 @@ +--- +layout: global +title: Spark SQL Programming Guide +--- +**Spark SQL is currently an Alpha component. Therefore, the APIs may be changed in future releases.** + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using +Spark. At the core of this component is a new type of RDD, +[SchemaRDD](api/sql/core/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of +[Row](api/sql/catalyst/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects along with +a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table +in a traditional relational database. A SchemaRDD can be created from an existing RDD, parquet +file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). 
+ +**All of the examples on this page use sample data included in the Spark distribution and can be run in the spark-shell.** + +*************************************************************************************************** + +# Getting Started + +The entry point into all relational functionality in Spark is the +[SQLContext](api/sql/core/index.html#org.apache.spark.sql.SQLContext) class, or one of its +descendants. To create a basic SQLContext, all you need is a SparkContext. + +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. +import sqlContext._ +{% endhighlight %} + +## Running SQL on RDDs +One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class +defines the schema of the table. The names of the arguments to the case class are read using +reflection and become the names of the columns. Case classes can also be nested or contain complex +types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be +registered as a table. Tables can be used in subsequent SQL statements. + +{% highlight scala %} +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +import sqlContext._ + +// Define the schema using a case class. +case class Person(name: String, age: Int) + +// Create an RDD of Person objects and register it as a table. +val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) +people.registerAsTable("people") + +// SQL statements can be run by using the sql methods provided by sqlContext. +val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. 
+teenagers.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + +**Note that Spark SQL currently uses a very basic SQL parser, and the keywords are case sensitive.** +Users that want a more complete dialect of SQL should look at the HiveQL support provided by +`HiveContext`. + +## Using Parquet + +Parquet is a columnar format that is supported by many other data processing systems. Spark SQL +provides support for both reading and writing parquet files that automatically preserves the schema +of the original data. Using the data from the above example: + +{% highlight scala %} +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +import sqlContext._ + +val people: RDD[Person] // An RDD of case class objects, from the previous example. + +// The RDD is implicitly converted to a SchemaRDD, allowing it to be stored using parquet. +people.saveAsParquetFile("people.parquet") + +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// The result of loading a parquet file is also a SchemaRDD. +val parquetFile = sqlContext.parquetFile("people.parquet") + +//Parquet files can also be registered as tables and then used in SQL statements. +parquetFile.registerAsTable("parquetFile") +val teenagers = sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenagers.collect().foreach(println) +{% endhighlight %} + +## Writing Language-Integrated Relational Queries + +Spark SQL also supports a domain specific language for writing queries. Once again, +using the data from the above examples: + +{% highlight scala %} +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +import sqlContext._ +val people: RDD[Person] // An RDD of case class objects, from the first example. 
+ +// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' +val teenagers = people.where('age >= 10).where('age <= 19).select('name) +{% endhighlight %} + +The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers +prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are +evaluated by the SQL execution engine. A full list of the functions supported can be found in the +[ScalaDoc](api/sql/core/index.html#org.apache.spark.sql.SchemaRDD). + + + +# Hive Support + +Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). +However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. +In order to use Hive you must first run '`sbt/sbt hive/assembly`'. This command builds a new assembly +jar that includes Hive. When this jar is present, Spark will use the Hive +assembly instead of the normal Spark assembly. Note that this Hive assembly jar must also be present +on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries +(SerDes) in order to access data stored in Hive. + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do +not have an existing Hive deployment can also experiment with the `LocalHiveContext`, +which is similar to `HiveContext`, but creates a local copy of the `metastore` and `warehouse` +automatically. + +{% highlight scala %} +val sc: SparkContext // An existing SparkContext. +val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) + +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. 
+import hiveContext._ + +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +// Queries are expressed in HiveQL +sql("SELECT key, value FROM src").collect().foreach(println) +{% endhighlight %} \ No newline at end of file diff --git a/examples/pom.xml b/examples/pom.xml index 382a38d94..a5569ff5e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -70,6 +70,12 @@ ${project.version} provided
    + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + provided + org.apache.spark spark-graphx_${scala.binary.version} diff --git a/examples/src/main/resources/kv1.txt b/examples/src/main/resources/kv1.txt new file mode 100644 index 000000000..9825414ec --- /dev/null +++ b/examples/src/main/resources/kv1.txt @@ -0,0 +1,500 @@ +238val_238 +86val_86 +311val_311 +27val_27 +165val_165 +409val_409 +255val_255 +278val_278 +98val_98 +484val_484 +265val_265 +193val_193 +401val_401 +150val_150 +273val_273 +224val_224 +369val_369 +66val_66 +128val_128 +213val_213 +146val_146 +406val_406 +429val_429 +374val_374 +152val_152 +469val_469 +145val_145 +495val_495 +37val_37 +327val_327 +281val_281 +277val_277 +209val_209 +15val_15 +82val_82 +403val_403 +166val_166 +417val_417 +430val_430 +252val_252 +292val_292 +219val_219 +287val_287 +153val_153 +193val_193 +338val_338 +446val_446 +459val_459 +394val_394 +237val_237 +482val_482 +174val_174 +413val_413 +494val_494 +207val_207 +199val_199 +466val_466 +208val_208 +174val_174 +399val_399 +396val_396 +247val_247 +417val_417 +489val_489 +162val_162 +377val_377 +397val_397 +309val_309 +365val_365 +266val_266 +439val_439 +342val_342 +367val_367 +325val_325 +167val_167 +195val_195 +475val_475 +17val_17 +113val_113 +155val_155 +203val_203 +339val_339 +0val_0 +455val_455 +128val_128 +311val_311 +316val_316 +57val_57 +302val_302 +205val_205 +149val_149 +438val_438 +345val_345 +129val_129 +170val_170 +20val_20 +489val_489 +157val_157 +378val_378 +221val_221 +92val_92 +111val_111 +47val_47 +72val_72 +4val_4 +280val_280 +35val_35 +427val_427 +277val_277 +208val_208 +356val_356 +399val_399 +169val_169 +382val_382 +498val_498 +125val_125 +386val_386 +437val_437 +469val_469 +192val_192 +286val_286 +187val_187 +176val_176 +54val_54 +459val_459 +51val_51 +138val_138 +103val_103 +239val_239 +213val_213 +216val_216 +430val_430 +278val_278 +176val_176 +289val_289 +221val_221 +65val_65 +318val_318 +332val_332 +311val_311 
+275val_275 +137val_137 +241val_241 +83val_83 +333val_333 +180val_180 +284val_284 +12val_12 +230val_230 +181val_181 +67val_67 +260val_260 +404val_404 +384val_384 +489val_489 +353val_353 +373val_373 +272val_272 +138val_138 +217val_217 +84val_84 +348val_348 +466val_466 +58val_58 +8val_8 +411val_411 +230val_230 +208val_208 +348val_348 +24val_24 +463val_463 +431val_431 +179val_179 +172val_172 +42val_42 +129val_129 +158val_158 +119val_119 +496val_496 +0val_0 +322val_322 +197val_197 +468val_468 +393val_393 +454val_454 +100val_100 +298val_298 +199val_199 +191val_191 +418val_418 +96val_96 +26val_26 +165val_165 +327val_327 +230val_230 +205val_205 +120val_120 +131val_131 +51val_51 +404val_404 +43val_43 +436val_436 +156val_156 +469val_469 +468val_468 +308val_308 +95val_95 +196val_196 +288val_288 +481val_481 +457val_457 +98val_98 +282val_282 +197val_197 +187val_187 +318val_318 +318val_318 +409val_409 +470val_470 +137val_137 +369val_369 +316val_316 +169val_169 +413val_413 +85val_85 +77val_77 +0val_0 +490val_490 +87val_87 +364val_364 +179val_179 +118val_118 +134val_134 +395val_395 +282val_282 +138val_138 +238val_238 +419val_419 +15val_15 +118val_118 +72val_72 +90val_90 +307val_307 +19val_19 +435val_435 +10val_10 +277val_277 +273val_273 +306val_306 +224val_224 +309val_309 +389val_389 +327val_327 +242val_242 +369val_369 +392val_392 +272val_272 +331val_331 +401val_401 +242val_242 +452val_452 +177val_177 +226val_226 +5val_5 +497val_497 +402val_402 +396val_396 +317val_317 +395val_395 +58val_58 +35val_35 +336val_336 +95val_95 +11val_11 +168val_168 +34val_34 +229val_229 +233val_233 +143val_143 +472val_472 +322val_322 +498val_498 +160val_160 +195val_195 +42val_42 +321val_321 +430val_430 +119val_119 +489val_489 +458val_458 +78val_78 +76val_76 +41val_41 +223val_223 +492val_492 +149val_149 +449val_449 +218val_218 +228val_228 +138val_138 +453val_453 +30val_30 +209val_209 +64val_64 +468val_468 +76val_76 +74val_74 +342val_342 +69val_69 +230val_230 +33val_33 +368val_368 +103val_103 +296val_296 
+113val_113 +216val_216 +367val_367 +344val_344 +167val_167 +274val_274 +219val_219 +239val_239 +485val_485 +116val_116 +223val_223 +256val_256 +263val_263 +70val_70 +487val_487 +480val_480 +401val_401 +288val_288 +191val_191 +5val_5 +244val_244 +438val_438 +128val_128 +467val_467 +432val_432 +202val_202 +316val_316 +229val_229 +469val_469 +463val_463 +280val_280 +2val_2 +35val_35 +283val_283 +331val_331 +235val_235 +80val_80 +44val_44 +193val_193 +321val_321 +335val_335 +104val_104 +466val_466 +366val_366 +175val_175 +403val_403 +483val_483 +53val_53 +105val_105 +257val_257 +406val_406 +409val_409 +190val_190 +406val_406 +401val_401 +114val_114 +258val_258 +90val_90 +203val_203 +262val_262 +348val_348 +424val_424 +12val_12 +396val_396 +201val_201 +217val_217 +164val_164 +431val_431 +454val_454 +478val_478 +298val_298 +125val_125 +431val_431 +164val_164 +424val_424 +187val_187 +382val_382 +5val_5 +70val_70 +397val_397 +480val_480 +291val_291 +24val_24 +351val_351 +255val_255 +104val_104 +70val_70 +163val_163 +438val_438 +119val_119 +414val_414 +200val_200 +491val_491 +237val_237 +439val_439 +360val_360 +248val_248 +479val_479 +305val_305 +417val_417 +199val_199 +444val_444 +120val_120 +429val_429 +169val_169 +443val_443 +323val_323 +325val_325 +277val_277 +230val_230 +478val_478 +178val_178 +468val_468 +310val_310 +317val_317 +333val_333 +493val_493 +460val_460 +207val_207 +249val_249 +265val_265 +480val_480 +83val_83 +136val_136 +353val_353 +172val_172 +214val_214 +462val_462 +233val_233 +406val_406 +133val_133 +175val_175 +189val_189 +454val_454 +375val_375 +401val_401 +421val_421 +407val_407 +384val_384 +256val_256 +26val_26 +134val_134 +67val_67 +384val_384 +379val_379 +18val_18 +462val_462 +492val_492 +100val_100 +298val_298 +9val_9 +341val_341 +498val_498 +146val_146 +458val_458 +362val_362 +186val_186 +285val_285 +348val_348 +167val_167 +18val_18 +273val_273 +183val_183 +281val_281 +344val_344 +97val_97 +469val_469 +315val_315 +84val_84 +28val_28 +37val_37 
+448val_448 +152val_152 +348val_348 +307val_307 +194val_194 +414val_414 +477val_477 +222val_222 +126val_126 +90val_90 +169val_169 +403val_403 +400val_400 +200val_200 +97val_97 diff --git a/examples/src/main/resources/people.txt b/examples/src/main/resources/people.txt new file mode 100644 index 000000000..3bcace4a4 --- /dev/null +++ b/examples/src/main/resources/people.txt @@ -0,0 +1,3 @@ +Michael, 29 +Andy, 30 +Justin, 19 diff --git a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala new file mode 100644 index 000000000..abcc1f04d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.examples + +import org.apache.spark.SparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.LocalHiveContext + +object HiveFromSpark { + case class Record(key: Int, value: String) + + def main(args: Array[String]) { + val sc = new SparkContext("local", "HiveFromSpark") + + // A local hive context creates an instance of the Hive Metastore in process, storing + // the warehouse data in the current directory. This location can be overridden by + // specifying a second parameter to the constructor. + val hiveContext = new LocalHiveContext(sc) + import hiveContext._ + + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + + // Queries are expressed in HiveQL + println("Result of 'SELECT *': ") + sql("SELECT * FROM src").collect.foreach(println) + + // Aggregation queries are also supported. + val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0) + println(s"COUNT(*): $count") + + // The results of SQL queries are themselves RDDs and support all normal RDD functions. The + // items in the RDD are of type Row, which allows you to access each column by ordinal. + val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + + println("Result of RDD.map:") + val rddAsStrings = rddFromSql.map { + case Row(key: Int, value: String) => s"Key: $key, Value: $value" + } + + // You can also register RDDs as temporary tables within a HiveContext. + val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + rdd.registerAsTable("records") + + // Queries can then join RDD data with data stored in Hive. 
+ println("Result of SELECT *:") + sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + } +} diff --git a/examples/src/main/scala/org/apache/spark/sql/examples/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/sql/examples/RDDRelation.scala new file mode 100644 index 000000000..8210ad977 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/sql/examples/RDDRelation.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.examples + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext + +// One method for defining the schema of an RDD is to make a case class with the desired column +// names and types. +case class Record(key: Int, value: String) + +object RDDRelation { + def main(args: Array[String]) { + val sc = new SparkContext("local", "RDDRelation") + val sqlContext = new SQLContext(sc) + + // Importing the SQL context gives access to all the SQL functions and implicit conversions. + import sqlContext._ + + val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) + // Any RDD containing case classes can be registered as a table. 
The schema of the table is + // automatically inferred using scala reflection. + rdd.registerAsTable("records") + + // Once tables have been registered, you can run SQL queries over them. + println("Result of SELECT *:") + sql("SELECT * FROM records").collect().foreach(println) + + // Aggregation queries are also supported. + val count = sql("SELECT COUNT(*) FROM records").collect().head.getInt(0) + println(s"COUNT(*): $count") + + // The results of SQL queries are themselves RDDs and support all normal RDD functions. The + // items in the RDD are of type Row, which allows you to access each column by ordinal. + val rddFromSql = sql("SELECT key, value FROM records WHERE key < 10") + + println("Result of RDD.map:") + rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect.foreach(println) + + // Queries can also be written using a LINQ-like Scala DSL. + rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + + // Write out an RDD as a parquet file. + rdd.saveAsParquetFile("pair.parquet") + + // Read in parquet file. Parquet files are self-describing so the schema is preserved. + val parquetFile = sqlContext.parquetFile("pair.parquet") + + // Queries can be run using the DSL on parquet files just like the original RDD. + parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + + // These files can also be registered as tables. 
+ parquetFile.registerAsTable("parquetFile") + sql("SELECT * FROM parquetFile").collect().foreach(println) + } +} diff --git a/graphx/pom.xml b/graphx/pom.xml index 894a7c264..5a5022916 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -29,7 +29,7 @@ spark-graphx_2.10 jar Spark Project GraphX - http://spark-project.org/ + http://spark.apache.org/ diff --git a/pom.xml b/pom.xml index 524e5daff..9db34a01b 100644 --- a/pom.xml +++ b/pom.xml @@ -91,6 +91,9 @@ mllib tools streaming + sql/catalyst + sql/core + sql/hive repl assembly external/twitter @@ -118,6 +121,8 @@ 2.4.1 0.23.7 0.94.6 + 0.12.0 + 1.3.2 64m 512m diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index aff191c98..e4ad65912 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -52,7 +52,7 @@ object SparkBuild extends Build { lazy val core = Project("core", file("core"), settings = coreSettings) lazy val repl = Project("repl", file("repl"), settings = replSettings) - .dependsOn(core, graphx, bagel, mllib) + .dependsOn(core, graphx, bagel, mllib, sql) lazy val tools = Project("tools", file("tools"), settings = toolsSettings) dependsOn(core) dependsOn(streaming) @@ -60,12 +60,19 @@ object SparkBuild extends Build { lazy val graphx = Project("graphx", file("graphx"), settings = graphxSettings) dependsOn(core) + lazy val catalyst = Project("catalyst", file("sql/catalyst"), settings = catalystSettings) dependsOn(core) + + lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core, catalyst) + + // Since hive is its own assembly, it depends on all of the modules. 
+ lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql, graphx, bagel, mllib, streaming, repl) + lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core) lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) - .dependsOn(core, graphx, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) dependsOn(maybeGanglia: _*) + .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeGanglia: _*) lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects") @@ -131,13 +138,13 @@ object SparkBuild extends Build { lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) - .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter) dependsOn(allExternal: _*) + .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter, hive) dependsOn(allExternal: _*) - // Everything except assembly, tools, java8Tests and examples belong to packageProjects - lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx) ++ maybeYarnRef ++ maybeGangliaRef + // Everything except assembly, hive, tools, java8Tests and examples belong to packageProjects + lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeGangliaRef lazy val allProjects = packageProjects ++ allExternalRefs ++ - Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests + Seq[ProjectReference](examples, tools, assemblyProj, hive) ++ maybeJava8Tests def sharedSettings = Defaults.defaultSettings ++ Seq( organization := 
"org.apache.spark", @@ -164,7 +171,7 @@ object SparkBuild extends Build { // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), // Remove certain packages from Scaladoc - scalacOptions in (Compile,doc) := Seq("-skip-packages", Seq( + scalacOptions in (Compile,doc) := Seq("-groups", "-skip-packages", Seq( "akka", "org.apache.spark.network", "org.apache.spark.deploy", @@ -362,6 +369,61 @@ object SparkBuild extends Build { ) ) + + def catalystSettings = sharedSettings ++ Seq( + name := "catalyst", + // The mechanics of rewriting expression ids to compare trees in some test cases makes + // assumptions about the expression ids being contiguous. Running tests in parallel breaks + // this non-deterministically. TODO: FIX THIS. + parallelExecution in Test := false, + libraryDependencies ++= Seq( + "org.scalatest" %% "scalatest" % "1.9.1" % "test", + "com.typesafe" %% "scalalogging-slf4j" % "1.0.1" + ) + ) + + def sqlCoreSettings = sharedSettings ++ Seq( + name := "spark-sql", + libraryDependencies ++= Seq( + "com.twitter" % "parquet-column" % "1.3.2", + "com.twitter" % "parquet-hadoop" % "1.3.2" + ) + ) + + // Since we don't include hive in the main assembly this project also acts as an alternative + // assembly jar. + def hiveSettings = sharedSettings ++ assemblyProjSettings ++ Seq( + name := "spark-hive", + jarName in assembly <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }, + jarName in packageDependency <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }, + javaOptions += "-XX:MaxPermSize=1g", + libraryDependencies ++= Seq( + "org.apache.hive" % "hive-metastore" % "0.12.0", + "org.apache.hive" % "hive-exec" % "0.12.0", + "org.apache.hive" % "hive-serde" % "0.12.0" + ), + // Multiple queries rely on the TestHive singleton. See comments there for more details. 
+ parallelExecution in Test := false, + // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings + // only for this subproject. + scalacOptions <<= scalacOptions map { currentOpts: Seq[String] => + currentOpts.filterNot(_ == "-deprecation") + }, + initialCommands in console := + """ + |import org.apache.spark.sql.catalyst.analysis._ + |import org.apache.spark.sql.catalyst.dsl._ + |import org.apache.spark.sql.catalyst.errors._ + |import org.apache.spark.sql.catalyst.expressions._ + |import org.apache.spark.sql.catalyst.plans.logical._ + |import org.apache.spark.sql.catalyst.rules._ + |import org.apache.spark.sql.catalyst.types._ + |import org.apache.spark.sql.catalyst.util._ + |import org.apache.spark.sql.execution + |import org.apache.spark.sql.hive._ + |import org.apache.spark.sql.hive.TestHive._ + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + ) + def streamingSettings = sharedSettings ++ Seq( name := "spark-streaming", libraryDependencies ++= Seq( diff --git a/sql/README.md b/sql/README.md new file mode 100644 index 000000000..4192fecb9 --- /dev/null +++ b/sql/README.md @@ -0,0 +1,80 @@ +Spark SQL +========= + +This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. + +Spark SQL is broken up into three subprojects: + - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. + - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. + - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. 
There are also wrappers that allow users to run queries that include Hive UDFs, UDAFs, and UDTFs. + + +Other dependencies for developers +--------------------------------- +In order to create new hive test cases, you will need to set several environment variables. + +``` +export HIVE_HOME="/hive/build/dist" +export HIVE_DEV_HOME="/hive/" +export HADOOP_HOME="/hadoop-1.0.4" +``` + +Using the console +================= +An interactive scala console can be invoked by running `sbt/sbt hive/console`. From here you can execute queries and inspect the various stages of query optimization. + +```scala +catalyst$ sbt/sbt hive/console + +[info] Starting scala interpreter... +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution +import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.TestHive._ +Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +Type in expressions to have them evaluated. +Type :help for more information. + +scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") +query: org.apache.spark.sql.ExecutedQuery = +SELECT * FROM (SELECT * FROM src) a +=== Query Plan === +Project [key#6:0.0,value#7:0.1] + HiveTableScan [key#6,value#7], (MetastoreRelation default, src, None), None +``` + +Query results are RDDs and can be operated as such. +``` +scala> query.collect() +res8: Array[org.apache.spark.sql.execution.Row] = Array([238,val_238], [86,val_86], [311,val_311]... +``` + +You can also build further queries on top of these RDDs using the query DSL.
+``` +scala> query.where('key === 100).toRdd.collect() +res11: Array[org.apache.spark.sql.execution.Row] = Array([100,val_100], [100,val_100]) +``` + +From the console you can even write rules that transform query plans. For example, the above query has redundant project operators that aren't doing anything. This redundancy can be eliminated using the `transform` function that is available on all [`TreeNode`](http://databricks.github.io/catalyst/latest/api/#catalyst.trees.TreeNode) objects. +```scala +scala> query.logicalPlan +res1: catalyst.plans.logical.LogicalPlan = +Project {key#0,value#1} + Project {key#0,value#1} + MetastoreRelation default, src, None + + +scala> query.logicalPlan transform { + | case Project(projectList, child) if projectList == child.output => child + | } +res2: catalyst.plans.logical.LogicalPlan = +Project {key#0,value#1} + MetastoreRelation default, src, None +``` diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml new file mode 100644 index 000000000..740f1fdc8 --- /dev/null +++ b/sql/catalyst/pom.xml @@ -0,0 +1,66 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-catalyst_2.10 + jar + Spark Project Catalyst + http://spark.apache.org/ + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + com.typesafe + scalalogging-slf4j_${scala.binary.version} + 1.0.1 + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala new file mode 100644 index 000000000..d3b1070a5 --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import scala.util.matching.Regex +import scala.util.parsing.combinator._ +import scala.util.parsing.input.CharArrayReader.EofCh +import lexical._ +import syntactical._ +import token._ + +import analysis._ +import expressions._ +import plans._ +import plans.logical._ +import types._ + +/** + * A very simple SQL parser. Based loosely on: + * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala + * + * Limitations: + * - Only supports a very limited subset of SQL. + * - Keywords must be capitalized. + * + * This is currently included mostly for illustrative purposes. Users wanting more complete support + * for a SQL-like language should check out the HiveQL support in the sql/hive subproject.
+ */ +class SqlParser extends StandardTokenParsers { + + def apply(input: String): LogicalPlan = { + phrase(query)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => sys.error(x.toString) + } + } + + protected case class Keyword(str: String) + protected implicit def asParser(k: Keyword): Parser[String] = k.str + + protected class SqlLexical extends StdLexical { + case class FloatLit(chars: String) extends Token { + override def toString = chars + } + override lazy val token: Parser[Token] = ( + identChar ~ rep( identChar | digit ) ^^ + { case first ~ rest => processIdent(first :: rest mkString "") } + | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ { + case i ~ None => NumericLit(i mkString "") + case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString("")) + } + | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^ + { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") } + | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^ + { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '\"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar = letter | elem('.') | elem('_') + + override def whitespace: Parser[Any] = rep( + whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') ) + | '#' ~ rep( chrExcept(EofCh, '\n') ) + | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') ) + | '/' ~ '*' ~ failure("unclosed comment") + ) + } + + override val lexical = new SqlLexical + + protected val ALL = Keyword("ALL") + protected val AND = Keyword("AND") + protected val AS = Keyword("AS") + protected val ASC = Keyword("ASC") + protected val AVG = Keyword("AVG") + protected val BY = Keyword("BY") + protected val CAST = Keyword("CAST") + protected val COUNT = Keyword("COUNT") + protected val DESC = Keyword("DESC") + protected val DISTINCT = Keyword("DISTINCT") + protected val FALSE = Keyword("FALSE") + 
protected val FIRST = Keyword("FIRST") + protected val FROM = Keyword("FROM") + protected val FULL = Keyword("FULL") + protected val GROUP = Keyword("GROUP") + protected val HAVING = Keyword("HAVING") + protected val IF = Keyword("IF") + protected val IN = Keyword("IN") + protected val INNER = Keyword("INNER") + protected val IS = Keyword("IS") + protected val JOIN = Keyword("JOIN") + protected val LEFT = Keyword("LEFT") + protected val LIMIT = Keyword("LIMIT") + protected val NOT = Keyword("NOT") + protected val NULL = Keyword("NULL") + protected val ON = Keyword("ON") + protected val OR = Keyword("OR") + protected val ORDER = Keyword("ORDER") + protected val OUTER = Keyword("OUTER") + protected val RIGHT = Keyword("RIGHT") + protected val SELECT = Keyword("SELECT") + protected val STRING = Keyword("STRING") + protected val SUM = Keyword("SUM") + protected val TRUE = Keyword("TRUE") + protected val UNION = Keyword("UNION") + protected val WHERE = Keyword("WHERE") + + // Use reflection to find the reserved words defined in this class. 
+ protected val reservedWords = + this.getClass + .getMethods + .filter(_.getReturnType == classOf[Keyword]) + .map(_.invoke(this).asInstanceOf[Keyword]) + + lexical.reserved ++= reservedWords.map(_.str) + + lexical.delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":" + ) + + protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { + exprs.zipWithIndex.map { + case (ne: NamedExpression, _) => ne + case (e, i) => Alias(e, s"c$i")() + } + } + + protected lazy val query: Parser[LogicalPlan] = + select * ( + UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | + UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + ) + + protected lazy val select: Parser[LogicalPlan] = + SELECT ~> opt(DISTINCT) ~ projections ~ + opt(from) ~ opt(filter) ~ + opt(grouping) ~ + opt(having) ~ + opt(orderBy) ~ + opt(limit) <~ opt(";") ^^ { + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + val base = r.getOrElse(NoRelation) + val withFilter = f.map(f => Filter(f, base)).getOrElse(base) + val withProjection = + g.map {g => + Aggregate(assignAliases(g), assignAliases(p), withFilter) + }.getOrElse(Project(assignAliases(p), withFilter)) + val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) + val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) + val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) + val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder) + withLimit + } + + protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") + + protected lazy val projection: Parser[Expression] = + expression ~ (opt(AS) ~> opt(ident)) ^^ { + case e ~ None => e + case e ~ Some(a) => Alias(e, a)() + } + + protected lazy val from: Parser[LogicalPlan] = FROM ~> relations + + // Based very loosely on the MySQL Grammar.
+ // http://dev.mysql.com/doc/refman/5.0/en/join.html + protected lazy val relations: Parser[LogicalPlan] = + relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } | + relation + + protected lazy val relation: Parser[LogicalPlan] = + joinedRelation | + relationFactor + + protected lazy val relationFactor: Parser[LogicalPlan] = + ident ~ (opt(AS) ~> opt(ident)) ^^ { + case ident ~ alias => UnresolvedRelation(alias, ident) + } | + "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } + + protected lazy val joinedRelation: Parser[LogicalPlan] = + relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ { + case r1 ~ jt ~ _ ~ r2 ~ cond => + Join(r1, r2, joinType = jt.getOrElse(Inner), cond) + } + + protected lazy val joinConditions: Parser[Expression] = + ON ~> expression + + protected lazy val joinType: Parser[JoinType] = + INNER ^^^ Inner | + LEFT ~ opt(OUTER) ^^^ LeftOuter | + RIGHT ~ opt(OUTER) ^^^ RightOuter | + FULL ~ opt(OUTER) ^^^ FullOuter + + protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e } + + protected lazy val orderBy: Parser[Seq[SortOrder]] = + ORDER ~> BY ~> ordering + + protected lazy val ordering: Parser[Seq[SortOrder]] = + rep1sep(singleOrder, ",") | + rep1sep(expression, ",") ~ opt(direction) ^^ { + case exps ~ None => exps.map(SortOrder(_, Ascending)) + case exps ~ Some(d) => exps.map(SortOrder(_, d)) + } + + protected lazy val singleOrder: Parser[SortOrder] = + expression ~ direction ^^ { case e ~ o => SortOrder(e,o) } + + protected lazy val direction: Parser[SortDirection] = + ASC ^^^ Ascending | + DESC ^^^ Descending + + protected lazy val grouping: Parser[Seq[Expression]] = + GROUP ~> BY ~> rep1sep(expression, ",") + + protected lazy val having: Parser[Expression] = + HAVING ~> expression + + protected lazy val limit: Parser[Expression] = + LIMIT ~> expression + + protected lazy val expression: Parser[Expression] = orExpression + + 
protected lazy val orExpression: Parser[Expression] = + andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) }) + + protected lazy val andExpression: Parser[Expression] = + comparisionExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) }) + + protected lazy val comparisionExpression: Parser[Expression] = + termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Equals(e1, e2) } | + termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } | + termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } | + termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } | + termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | + termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { + case e1 ~ _ ~ _ ~ e2 => In(e1, e2) + } | + termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { + case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2)) + } | + termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } | + termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } | + NOT ~> termExpression ^^ {e => Not(e)} | + termExpression + + protected lazy val termExpression: Parser[Expression] = + productExpression * ( + "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } | + "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } ) + + protected lazy val productExpression: Parser[Expression] = + baseExpression * ( + "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } | + "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } | + "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) } + ) + + protected lazy val function: Parser[Expression] = + SUM ~> "(" ~> expression <~ ")" ^^ 
{ case exp => Sum(exp) } | + SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } | + COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | + COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } | + COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } | + FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | + AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } | + IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { + case c ~ "," ~ t ~ "," ~ f => If(c,t,f) + } | + ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { + case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) + } + + protected lazy val cast: Parser[Expression] = + CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) } + + protected lazy val literal: Parser[Literal] = + numericLit ^^ { + case i if i.toLong > Int.MaxValue => Literal(i.toLong) + case i => Literal(i.toInt) + } | + NULL ^^^ Literal(null, NullType) | + floatLit ^^ {case f => Literal(f.toDouble) } | + stringLit ^^ {case s => Literal(s, StringType) } + + protected lazy val floatLit: Parser[String] = + elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + + protected lazy val baseExpression: Parser[Expression] = + TRUE ^^^ Literal(true, BooleanType) | + FALSE ^^^ Literal(false, BooleanType) | + cast | + "(" ~> expression <~ ")" | + function | + "-" ~> literal ^^ UnaryMinus | + ident ^^ UnresolvedAttribute | + "*" ^^^ Star(None) | + literal + + protected lazy val dataType: Parser[DataType] = + STRING ^^^ StringType +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala new file mode 100644 index 000000000..9eb992ee5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -0,0 +1,185 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical._ +import rules._ + +/** + * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing + * when all relations are already filled in and the analyser needs only to resolve attribute + * references. + */ +object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true) + +/** + * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and + * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and + * a [[FunctionRegistry]]. + */ +class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean) + extends RuleExecutor[LogicalPlan] with HiveTypeCoercion { + + // TODO: pass this in as a parameter. 
+ val fixedPoint = FixedPoint(100) + + val batches: Seq[Batch] = Seq( + Batch("MultiInstanceRelations", Once, + NewRelationInstances), + Batch("CaseInsensitiveAttributeReferences", Once, + (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*), + Batch("Resolution", fixedPoint, + ResolveReferences :: + ResolveRelations :: + NewRelationInstances :: + ImplicitGenerate :: + StarExpansion :: + ResolveFunctions :: + GlobalAggregates :: + typeCoercionRules :_*) + ) + + /** + * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. + */ + object ResolveRelations extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case UnresolvedRelation(databaseName, name, alias) => + catalog.lookupRelation(databaseName, name, alias) + } + } + + /** + * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase. + */ + object LowercaseAttributeReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case UnresolvedRelation(databaseName, name, alias) => + UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase)) + case Subquery(alias, child) => Subquery(alias.toLowerCase, child) + case q: LogicalPlan => q transformExpressions { + case s: Star => s.copy(table = s.table.map(_.toLowerCase)) + case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase) + case Alias(c, name) => Alias(c, name.toLowerCase)() + } + } + } + + /** + * Replaces [[UnresolvedAttribute]]s with concrete + * [[expressions.AttributeReference AttributeReferences]] from a logical plan node's children. + */ + object ResolveReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case q: LogicalPlan if q.childrenResolved => + logger.trace(s"Attempting to resolve ${q.simpleString}") + q transformExpressions { + case u @ UnresolvedAttribute(name) => + // Leave unchanged if resolution fails. 
Hopefully will be resolved next round. + val result = q.resolve(name).getOrElse(u) + logger.debug(s"Resolving $u to $result") + result + } + } + } + + /** + * Replaces [[UnresolvedFunction]]s with concrete [[expressions.Expression Expressions]]. + */ + object ResolveFunctions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => + q transformExpressions { + case u @ UnresolvedFunction(name, children) if u.childrenResolved => + registry.lookupFunction(name, children) + } + } + } + + /** + * Turns projections that contain aggregate expressions into aggregations. + */ + object GlobalAggregates extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(projectList, child) if containsAggregates(projectList) => + Aggregate(Nil, projectList, child) + } + + def containsAggregates(exprs: Seq[Expression]): Boolean = { + exprs.foreach(_.foreach { + case agg: AggregateExpression => return true + case _ => + }) + false + } + } + + /** + * When a SELECT clause has only a single expression and that expression is a + * [[catalyst.expressions.Generator Generator]] we convert the + * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]]. + */ + object ImplicitGenerate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Project(Seq(Alias(g: Generator, _)), child) => + Generate(g, join = false, outer = false, None, child) + } + } + + /** + * Expands any references to [[Star]] (*) in project operators. + */ + object StarExpansion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Wait until children are resolved + case p: LogicalPlan if !p.childrenResolved => p + // If the projection list contains Stars, expand it. 
+ case p @ Project(projectList, child) if containsStar(projectList) => + Project( + projectList.flatMap { + case s: Star => s.expand(child.output) + case o => o :: Nil + }, + child) + case t: ScriptTransformation if containsStar(t.input) => + t.copy( + input = t.input.flatMap { + case s: Star => s.expand(t.child.output) + case o => o :: Nil + } + ) + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + a.copy( + aggregateExpressions = a.aggregateExpressions.flatMap { + case s: Star => s.expand(a.child.output) + case o => o :: Nil + } + ) + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + protected def containsStar(exprs: Seq[Expression]): Boolean = + exprs.collect { case _: Star => true }.nonEmpty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala new file mode 100644 index 000000000..71e4dcdb1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package analysis + +import plans.logical.{LogicalPlan, Subquery} +import scala.collection.mutable + +/** + * An interface for looking up relations by name. Used by an [[Analyzer]]. + */ +trait Catalog { + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan + + def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit +} + +class SimpleCatalog extends Catalog { + val tables = new mutable.HashMap[String, LogicalPlan]() + + def registerTable(databaseName: Option[String],tableName: String, plan: LogicalPlan): Unit = { + tables += ((tableName, plan)) + } + + def dropTable(tableName: String) = tables -= tableName + + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan = { + val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName")) + + // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are + // properly qualified with this alias. + alias.map(a => Subquery(a.toLowerCase, table)).getOrElse(table) + } +} + +/** + * A trait that can be mixed in with other Catalogs allowing specific tables to be overridden with + * new logical plans. This can be used to bind query result to virtual tables, or replace tables + * with in-memory cached versions. Note that the set of overrides is stored in memory and thus + * lost when the JVM exits. + */ +trait OverrideCatalog extends Catalog { + + // TODO: This doesn't work when the database changes... 
+ val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + + abstract override def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None): LogicalPlan = { + + val overriddenTable = overrides.get((databaseName, tableName)) + + // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are + // properly qualified with this alias. + val withAlias = + overriddenTable.map(r => alias.map(a => Subquery(a.toLowerCase, r)).getOrElse(r)) + + withAlias.getOrElse(super.lookupRelation(databaseName, tableName, alias)) + } + + override def registerTable( + databaseName: Option[String], + tableName: String, + plan: LogicalPlan): Unit = { + overrides.put((databaseName, tableName), plan) + } +} + +/** + * A trivial catalog that returns an error when a relation is requested. Used for testing when all + * relations are already filled in and the analyser needs only to resolve attribute references. + */ +object EmptyCatalog extends Catalog { + def lookupRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None) = { + throw new UnsupportedOperationException + } + + def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala new file mode 100644 index 000000000..a359eb541 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ + +/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ +trait FunctionRegistry { + def lookupFunction(name: String, children: Seq[Expression]): Expression +} + +/** + * A trivial catalog that returns an error when a function is requested. Used for testing when all + * functions are already filled in and the analyser needs only to resolve attribute references. + */ +object EmptyFunctionRegistry extends FunctionRegistry { + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala new file mode 100644 index 000000000..a0105cd7c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical._ +import rules._ +import types._ + +/** + * A collection of [[catalyst.rules.Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ +trait HiveTypeCoercion { + + val typeCoercionRules = + List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, + StringToIntegralCasts, FunctionArgumentConversion) + + /** + * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] dataTypes + * that are made by other rules to instances higher in the query tree. + */ + object PropagateTypes extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // No propagation required for leaf nodes. + case q: LogicalPlan if q.children.isEmpty => q + + // Don't propagate types from unresolved children. 
+ case q: LogicalPlan if !q.childrenResolved => q + + case q: LogicalPlan => q transformExpressions { + case a: AttributeReference => + q.inputSet.find(_.exprId == a.exprId) match { + // This can happen when an Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + newType + } + } + } + } + + /** + * Converts string "NaN"s that are in binary operators with NaN-able types (Float / Double) to + * the appropriate numeric equivalent. + */ + object ConvertNaNs extends Rule[LogicalPlan] { + val stringNaN = Literal("NaN", StringType) + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressions { + // Skip nodes whose children have not been resolved yet. 
+ case e if !e.childrenResolved => e + + /* Double Conversions */ + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType => + b.makeCopy(Array(b.right, Literal(Double.NaN))) + case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Double.NaN), b.left)) + + /* Float Conversions */ + case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType => + b.makeCopy(Array(b.right, Literal(Float.NaN))) + case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) + case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN => + b.makeCopy(Array(Literal(Float.NaN), b.left)) + } + } + } + + /** + * Widens numeric types and converts strings to numbers when appropriate. + * + * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White + * + * The implicit conversion rules can be summarized as follows: + * - Any integral numeric type can be implicitly converted to a wider type. + * - All the integral numeric types, FLOAT, and (perhaps surprisingly) STRING can be implicitly + * converted to DOUBLE. + * - TINYINT, SMALLINT, and INT can all be converted to FLOAT. + * - BOOLEAN types cannot be converted to any other type. + * + * Additionally, all types when UNION-ed with strings will be promoted to strings. + * Other string conversions are handled by PromoteStrings. + */ + object WidenTypes extends Rule[LogicalPlan] { + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
+ // The conversion for integral and floating point types have a linear widening hierarchy: + val numericPrecedence = + Seq(NullType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) + // Boolean is only wider than Void + val booleanPrecedence = Seq(NullType, BooleanType) + val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: booleanPrecedence :: Nil + + def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { + // Try and find a promotion rule that contains both types in question. + val applicableConversion = allPromotions.find(p => p.contains(t1) && p.contains(t2)) + + // If found return the widest common type, otherwise None + applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case u @ Union(left, right) if u.childrenResolved && !u.resolved => + val castedInput = left.output.zip(right.output).map { + // When a string is found on one side, make the other side a string too. + case (l, r) if l.dataType == StringType && r.dataType != StringType => + (l, Alias(Cast(r, StringType), r.name)()) + case (l, r) if l.dataType != StringType && r.dataType == StringType => + (Alias(Cast(l, StringType), l.name)(), r) + + case (l, r) if l.dataType != r.dataType => + logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + findTightestCommonType(l.dataType, r.dataType).map { widestType => + val newLeft = + if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() + val newRight = + if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)() + + (newLeft, newRight) + }.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged. 
+ case other => other + } + + val (castedLeft, castedRight) = castedInput.unzip + + val newLeft = + if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { + logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + Project(castedLeft, left) + } else { + left + } + + val newRight = + if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { + logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + Project(castedRight, right) + } else { + right + } + + Union(newLeft, newRight) + + // Also widen types for BinaryExpressions. + case q: LogicalPlan => q transformExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case b: BinaryExpression if b.left.dataType != b.right.dataType => + findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + val newLeft = + if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) + val newRight = + if (b.right.dataType == widestType) b.right else Cast(b.right, widestType) + b.makeCopy(Array(newLeft, newRight)) + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + } + } + } + + /** + * Promotes strings that appear in arithmetic expressions. + */ + object PromoteStrings extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. 
+ case e if !e.childrenResolved => e + + case a: BinaryArithmetic if a.left.dataType == StringType => + a.makeCopy(Array(Cast(a.left, DoubleType), a.right)) + case a: BinaryArithmetic if a.right.dataType == StringType => + a.makeCopy(Array(a.left, Cast(a.right, DoubleType))) + + case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType => + p.makeCopy(Array(Cast(p.left, DoubleType), p.right)) + case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType => + p.makeCopy(Array(p.left, Cast(p.right, DoubleType))) + + case Sum(e) if e.dataType == StringType => + Sum(Cast(e, DoubleType)) + case Average(e) if e.dataType == StringType => + Average(Cast(e, DoubleType)) + } + } + + /** + * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + */ + object BooleanComparisons extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + // No need to change Equals operators as that actually makes sense for boolean types. + case e: Equals => e + // Otherwise turn them to Byte types so that there exists and ordering. + case p: BinaryComparison + if p.left.dataType == BooleanType && p.right.dataType == BooleanType => + p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + } + } + + /** + * Casts to/from [[catalyst.types.BooleanType BooleanType]] are transformed into comparisons since + * the JVM does not consider Booleans to be numeric types. + */ + object BooleanCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. 
+ case e if !e.childrenResolved => e + + case Cast(e, BooleanType) => Not(Equals(e, Literal(0))) + case Cast(e, dataType) if e.dataType == BooleanType => + Cast(If(e, Literal(1), Literal(0)), dataType) + } + } + + /** + * When encountering a cast from a string representing a valid fractional number to an integral + * type the JVM will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the + * truncated version of this number. + */ + object StringToIntegralCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + case Cast(e @ StringType(), t: IntegralType) => + Cast(Cast(e, DecimalType), t) + } + } + + /** + * This ensures that the types for various functions are as expected. + */ + object FunctionArgumentConversion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet. + case e if !e.childrenResolved => e + + // Promote SUM to largest types to prevent overflows. + case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. + case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) + case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala new file mode 100644 index 000000000..fe18cc466 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst +package analysis + +import plans.logical.LogicalPlan +import rules._ + +/** + * A trait that should be mixed into query operators where a single instance might appear multiple + * times in a logical query plan. It is invalid to have multiple copies of the same attribute + * produced by distinct operators in a query tree as this breaks the guarantee that expression + * ids, which are used to differentiate attributes, are unique. + * + * Before analysis, all operators that include this trait will be asked to produce a new version + * of itself with globally unique expression ids. + */ +trait MultiInstanceRelation { + def newInstance: this.type +} + +/** + * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so + * that each instance has unique expression ids for the attributes produced. 
+ */ +object NewRelationInstances extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val localRelations = plan collect { case l: MultiInstanceRelation => l} + val multiAppearance = localRelations + .groupBy(identity[MultiInstanceRelation]) + .filter { case (_, ls) => ls.size > 1 } + .map(_._1) + .toSet + + plan transform { + case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala new file mode 100644 index 000000000..375c99f48 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package catalyst + +/** + * Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis. + * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s + * into fully typed objects using information in a schema [[Catalog]]. 
+ */ +package object analysis diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala new file mode 100644 index 000000000..2ed2af135 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package analysis + +import expressions._ +import plans.logical.BaseRelation +import trees.TreeNode + +/** + * Thrown when an invalid attempt is made to access a property of a tree that has yet to be fully + * resolved. + */ +class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: String) extends + errors.TreeNodeException(tree, s"Invalid call to $function on unresolved object", null) + +/** + * Holds the name of a relation that has yet to be looked up in a [[Catalog]]. 
+ */ +case class UnresolvedRelation( + databaseName: Option[String], + tableName: String, + alias: Option[String] = None) extends BaseRelation { + def output = Nil + override lazy val resolved = false +} + +/** + * Holds the name of an attribute that has yet to be resolved. + */ +case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + override lazy val resolved = false + + def newInstance = this + def withQualifiers(newQualifiers: Seq[String]) = this + + override def toString: String = s"'$name" +} + +case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + override def foldable = throw new UnresolvedException(this, "foldable") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + def references = children.flatMap(_.references).toSet + override lazy val resolved = false + override def toString = s"'$name(${children.mkString(",")})" +} + +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT * FROM ...". + * + * @param table an optional table that should be the target of the expansion. If omitted all + * tables' columns are produced. 
+ */ +case class Star( + table: Option[String], + mapFunction: Attribute => Expression = identity[Attribute]) + extends Attribute with trees.LeafNode[Expression] { + + def name = throw new UnresolvedException(this, "exprId") + def exprId = throw new UnresolvedException(this, "exprId") + def dataType = throw new UnresolvedException(this, "dataType") + def nullable = throw new UnresolvedException(this, "nullable") + def qualifiers = throw new UnresolvedException(this, "qualifiers") + override lazy val resolved = false + + def newInstance = this + def withQualifiers(newQualifiers: Seq[String]) = this + + def expand(input: Seq[Attribute]): Seq[NamedExpression] = { + val expandedAttributes: Seq[Attribute] = table match { + // If there is no table specified, use all input attributes. + case None => input + // If there is a table, pick out attributes that are part of this table. + case Some(table) => input.filter(_.qualifiers contains table) + } + val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { + case (n: NamedExpression, _) => n + case (e, originalAttribute) => + Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) + } + mappedAttributes + } + + override def toString = table.map(_ + ".").getOrElse("") + "*" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala new file mode 100644 index 000000000..cd8de9d52 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import analysis.UnresolvedAttribute +import expressions._ +import plans._ +import plans.logical._ +import types._ + +/** + * Provides experimental support for generating catalyst schemas for scala objects. + */ +object ScalaReflection { + import scala.reflect.runtime.universe._ + + /** Returns a Sequence of attributes for the given case class type. */ + def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { + case s: StructType => + s.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) + } + + /** Returns a catalyst DataType for the given Scala Type using reflection. */ + def schemaFor[T: TypeTag]: DataType = schemaFor(typeOf[T]) + + /** Returns a catalyst DataType for the given Scala Type using reflection. 
*/ + def schemaFor(tpe: `Type`): DataType = tpe match { + case t if t <:< typeOf[Product] => + val params = t.member("": TermName).asMethod.paramss + StructType( + params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true))) + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + ArrayType(schemaFor(elementType)) + case t if t <:< typeOf[String] => StringType + case t if t <:< definitions.IntTpe => IntegerType + case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.ShortTpe => ShortType + case t if t <:< definitions.ByteTpe => ByteType + } + + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { + + /** + * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation + * for the the data in the sequence. + */ + def asRelation: LocalRelation = { + val output = attributesFor[A] + LocalRelation(output, data) + } + } +} + +/** + * A collection of implicit conversions that create a DSL for constructing catalyst data structures. + * + * {{{ + * scala> import catalyst.dsl._ + * + * // Standard operators are added to expressions. + * scala> Literal(1) + Literal(1) + * res1: catalyst.expressions.Add = (1 + 1) + * + * // There is a conversion from 'symbols to unresolved attributes. + * scala> 'a.attr + * res2: catalyst.analysis.UnresolvedAttribute = 'a + * + * // These unresolved attributes can be used to create more complicated expressions. + * scala> 'a === 'b + * res3: catalyst.expressions.Equals = ('a = 'b) + * + * // SQL verbs can be used to construct logical query plans. 
+ * scala> TestRelation('key.int, 'value.string).where('key === 1).select('value).analyze + * res4: catalyst.plans.logical.LogicalPlan = + * Project {value#1} + * Filter (key#0 = 1) + * TestRelation {key#0,value#1} + * }}} + */ +package object dsl { + trait ImplicitOperators { + def expr: Expression + + def + (other: Expression) = Add(expr, other) + def - (other: Expression) = Subtract(expr, other) + def * (other: Expression) = Multiply(expr, other) + def / (other: Expression) = Divide(expr, other) + + def && (other: Expression) = And(expr, other) + def || (other: Expression) = Or(expr, other) + + def < (other: Expression) = LessThan(expr, other) + def <= (other: Expression) = LessThanOrEqual(expr, other) + def > (other: Expression) = GreaterThan(expr, other) + def >= (other: Expression) = GreaterThanOrEqual(expr, other) + def === (other: Expression) = Equals(expr, other) + def != (other: Expression) = Not(Equals(expr, other)) + + def asc = SortOrder(expr, Ascending) + def desc = SortOrder(expr, Descending) + + def as(s: Symbol) = Alias(expr, s.name)() + } + + trait ExpressionConversions { + implicit class DslExpression(e: Expression) extends ImplicitOperators { + def expr = e + } + + implicit def intToLiteral(i: Int) = Literal(i) + implicit def longToLiteral(l: Long) = Literal(l) + implicit def floatToLiteral(f: Float) = Literal(f) + implicit def doubleToLiteral(d: Double) = Literal(d) + implicit def stringToLiteral(s: String) = Literal(s) + + implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) + + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + implicit class DslString(val s: String) extends ImplicitAttribute + + abstract class ImplicitAttribute extends ImplicitOperators { + def s: String + def expr = attr + def attr = analysis.UnresolvedAttribute(s) + + /** Creates a new typed attributes of type int */ + def int = AttributeReference(s, IntegerType, nullable = false)() + + /** Creates a 
new typed attributes of type string */ + def string = AttributeReference(s, StringType, nullable = false)() + } + + implicit class DslAttribute(a: AttributeReference) { + def notNull = a.withNullability(false) + def nullable = a.withNullability(true) + + // Protobuf terminology + def required = a.withNullability(false) + } + } + + + object expressions extends ExpressionConversions // scalastyle:ignore + + abstract class LogicalPlanFunctions { + def logicalPlan: LogicalPlan + + def select(exprs: NamedExpression*) = Project(exprs, logicalPlan) + + def where(condition: Expression) = Filter(condition, logicalPlan) + + def join( + otherPlan: LogicalPlan, + joinType: JoinType = Inner, + condition: Option[Expression] = None) = + Join(logicalPlan, otherPlan, joinType, condition) + + def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan) + + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) + } + + def subquery(alias: Symbol) = Subquery(alias.name, logicalPlan) + + def unionAll(otherPlan: LogicalPlan) = Union(logicalPlan, otherPlan) + + def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean) = + Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + + def sfilter(dynamicUdf: (DynamicRow) => Boolean) = + Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan) + + def sample( + fraction: Double, + withReplacement: Boolean = true, + seed: Int = (math.random * 1000).toInt) = + Sample(fraction, withReplacement, seed, logicalPlan) + + def generate( + generator: Generator, + join: Boolean = false, + outer: Boolean = false, + alias: Option[String] = None) = + Generate(generator, join, outer, None, logicalPlan) + + def insertInto(tableName: String, overwrite: Boolean = false) = + InsertIntoTable( + 
analysis.UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite) + + def analyze = analysis.SimpleAnalyzer(logicalPlan) + } + + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { + def writeToFile(path: String) = WriteToFile(path, logicalPlan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala new file mode 100644 index 000000000..c253587f6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import trees._ + +/** + * Functions for attaching and retrieving trees that are associated with errors. + */ +package object errors { + + class TreeNodeException[TreeType <: TreeNode[_]] + (tree: TreeType, msg: String, cause: Throwable) extends Exception(msg, cause) { + + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT + // external project dependencies for some reason. 
+ def this(tree: TreeType, msg: String) = this(tree, msg, null) + + override def getMessage: String = { + val treeString = tree.toString + s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" + } + } + + /** + * Wraps any exceptions that are thrown while executing `f` in a + * [[catalyst.errors.TreeNodeException TreeNodeException]], attaching the provided `tree`. + */ + def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = { + try f catch { + case e: Exception => throw new TreeNodeException(tree, msg, e) + } + } + + /** + * Executes `f` which is expected to throw a + * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in + * the stack of exceptions of type `TreeType` is returned. + */ + def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala new file mode 100644 index 000000000..3b6bac16f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import rules._ +import errors._ + +import catalyst.plans.QueryPlan + +/** + * A bound reference points to a specific slot in the input tuple, allowing the actual value + * to be retrieved more efficiently. However, since operations like column pruning can change + * the layout of intermediate tuples, BindReferences should be run after all such transformations. + */ +case class BoundReference(ordinal: Int, baseReference: Attribute) + extends Attribute with trees.LeafNode[Expression] { + + type EvaluatedType = Any + + def nullable = baseReference.nullable + def dataType = baseReference.dataType + def exprId = baseReference.exprId + def qualifiers = baseReference.qualifiers + def name = baseReference.name + + def newInstance = BoundReference(ordinal, baseReference.newInstance) + def withQualifiers(newQualifiers: Seq[String]) = + BoundReference(ordinal, baseReference.withQualifiers(newQualifiers)) + + override def toString = s"$baseReference:$ordinal" + + override def apply(input: Row): Any = input(ordinal) +} + +class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { + import BindReferences._ + + def apply(plan: TreeNode): TreeNode = { + plan.transform { + case leafNode if leafNode.children.isEmpty => leafNode + case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => + bindReference(e, unaryNode.children.head.output) + } + } + } +} + +object BindReferences extends Logging { + def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { + expression.transform { case a: AttributeReference => + attachTree(a, "Binding attribute") { + val ordinal = input.indexWhere(_.exprId == a.exprId) + if (ordinal == -1) { + // TODO: This fallback is required because some operators (such as ScriptTransform) + // produce new 
attributes that can't be bound. Likely the right thing to do is remove + // this rule and require all operators to explicitly bind to the input schema that + // they specify. + logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + a + } else { + BoundReference(ordinal, a) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala new file mode 100644 index 000000000..608656d3a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +/** Cast the child expression to the target data type. 
*/ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"CAST($child, $dataType)" + + type EvaluatedType = Any + + lazy val castingFunction: Any => Any = (child.dataType, dataType) match { + case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]]) + case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes + case (_, StringType) => a: Any => a.toString + case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt) + case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble) + case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat) + case (StringType, LongType) => a: Any => castOrNull(a, _.toLong) + case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort) + case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte) + case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_)) + case (BooleanType, ByteType) => a: Any => a match { + case null => null + case true => 1.toByte + case false => 0.toByte + } + case (dt, IntegerType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a) + case (dt, DoubleType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a) + case (dt, FloatType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a) + case (dt, LongType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a) + case (dt, ShortType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort + case (dt, ByteType) => + a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte + case (dt, DecimalType) => + a: Any => + BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)) + } + + @inline + protected def 
castOrNull[A](a: Any, f: String => A) = + try f(a.asInstanceOf[String]) catch { + case _: java.lang.NumberFormatException => null + } + + override def apply(input: Row): Any = { + val evaluated = child.apply(input) + if (evaluated == null) { + null + } else { + castingFunction(evaluated) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala new file mode 100644 index 000000000..78aaaeebb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import errors._ +import trees._ +import types._ + +abstract class Expression extends TreeNode[Expression] { + self: Product => + + /** The narrowest possible type that is produced when this expression is evaluated. */ + type EvaluatedType <: Any + + def dataType: DataType + + /** + * Returns true when an expression is a candidate for static evaluation before the query is + * executed. 
+ *
 + * The following conditions are used to determine suitability for constant folding:
 + *  - A [[expressions.Coalesce Coalesce]] is foldable if all of its children are foldable
 + *  - A [[expressions.BinaryExpression BinaryExpression]] is foldable if both its left and right
 + *    children are foldable
 + *  - A [[expressions.Not Not]], [[expressions.IsNull IsNull]], or
 + *    [[expressions.IsNotNull IsNotNull]] is foldable if its child is foldable.
 + *  - A [[expressions.Literal]] is foldable.
 + *  - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
 + *    child is foldable.
 + */
 +  // TODO: Support more foldable expressions. For example, deterministic Hive UDFs.
 +  def foldable: Boolean = false
 +  def nullable: Boolean
 +  def references: Set[Attribute]
 +
 +  /** Returns the result of evaluating this expression on a given input Row */
 +  def apply(input: Row = null): EvaluatedType =
 +    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 +
 +  /**
 +   * Returns `true` if this expression and all its children have been resolved to a specific schema
 +   * and `false` if it still contains any unresolved placeholders. Implementations of expressions
 +   * should override this if the resolution of this type of expression involves more than just
 +   * the resolution of its children.
 +   */
 +  lazy val resolved: Boolean = childrenResolved
 +
 +  /**
 +   * Returns true if all the children of this expression have been resolved to a specific schema
 +   * and false if any of them still contains unresolved placeholders.
 +   */
 +  def childrenResolved = !children.exists(!_.resolved)
 +
 +  /**
 +   * A set of helper functions that return the correct descendant of [[scala.math.Numeric]] type
 +   * and do any casting necessary of child evaluation.
+ */ + @inline + def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = { + val evalE = e.apply(i) + if (evalE == null) { + null + } else { + e.dataType match { + case n: NumericType => + val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType] + castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType]) + case other => sys.error(s"Type $other does not support numeric operations") + } + } + } + + @inline + protected final def n2( + i: Row, + e1: Expression, + e2: Expression, + f: ((Numeric[Any], Any, Any) => Any)): Any = { + + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.apply(i) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.apply(i) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case n: NumericType => + f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => Int]( + n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType]) + case other => sys.error(s"Type $other does not support numeric operations") + } + } + } + } + + @inline + protected final def f2( + i: Row, + e1: Expression, + e2: Expression, + f: ((Fractional[Any], Any, Any) => Any)): Any = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.apply(i: Row) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.apply(i: Row) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case ft: FractionalType => + f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType]( + ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType]) + case other => sys.error(s"Type $other does not support fractional operations") + } + } + } + } + + @inline + protected final def i2( + i: Row, + e1: Expression, + e2: Expression, + f: ((Integral[Any], Any, Any) => Any)): Any = { + if 
(e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.apply(i) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.apply(i) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case i: IntegralType => + f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( + i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case other => sys.error(s"Type $other does not support numeric operations") + } + } + } + } +} + +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + def symbol: String + + override def foldable = left.foldable && right.foldable + + def references = left.references ++ right.references + + override def toString = s"($left $symbol $right)" +} + +abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { + self: Product => +} + +abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { + self: Product => + + def references = child.references +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala new file mode 100644 index 000000000..8c407d2fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
 + *
 + *    http://www.apache.org/licenses/LICENSE-2.0
 + *
 + * Unless required by applicable law or agreed to in writing, software
 + * distributed under the License is distributed on an "AS IS" BASIS,
 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 + * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +
 +package org.apache.spark.sql.catalyst
 +package expressions
 +
 +/**
 + * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
 + * new row. If the schema of the input row is specified, then the given expressions will be bound to
 + * that schema.
 + */
 +class Projection(expressions: Seq[Expression]) extends (Row => Row) {
 +  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
 +    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 +
 +  protected val exprArray = expressions.toArray
 +  def apply(input: Row): Row = {
 +    val outputArray = new Array[Any](exprArray.size)
 +    var i = 0
 +    while (i < exprArray.size) {
 +      outputArray(i) = exprArray(i).apply(input)
 +      i += 1
 +    }
 +    new GenericRow(outputArray)
 +  }
 +}
 +
 +/**
 + * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the
 + * new row. If the schema of the input row is specified, then the given expressions will be bound to
 + * that schema.
 + *
 + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object
 + * each time an input row is added.
This significantly reduces the cost of calculating the
 + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()`
 + * has been called on the [[Iterator]] that produced it.
 + */
 +case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
 +  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
 +    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 +
 +  private[this] val exprArray = expressions.toArray
 +  private[this] val mutableRow = new GenericMutableRow(exprArray.size)
 +  def currentValue: Row = mutableRow
 +
 +  def apply(input: Row): Row = {
 +    var i = 0
 +    while (i < exprArray.size) {
 +      mutableRow(i) = exprArray(i).apply(input)
 +      i += 1
 +    }
 +    mutableRow
 +  }
 +}
 +
 +/**
 + * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
 + * be instantiated once per thread and reused.
 + */
 +class JoinedRow extends Row {
 +  private[this] var row1: Row = _
 +  private[this] var row2: Row = _
 +
 +  /** Updates this JoinedRow to point at two new base rows. Returns itself. */
 +  def apply(r1: Row, r2: Row): Row = {
 +    row1 = r1
 +    row2 = r2
 +    this
 +  }
 +
 +  def iterator = row1.iterator ++ row2.iterator
 +
 +  def length = row1.length + row2.length
 +
 +  def apply(i: Int) =
 +    if (i < row1.size) row1(i) else row2(i - row1.size)
 +
 +  def isNullAt(i: Int) = apply(i) == null
 +
 +  def getInt(i: Int): Int =
 +    if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size)
 +
 +  def getLong(i: Int): Long =
 +    if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size)
 +
 +  def getDouble(i: Int): Double =
 +    if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size)
 +
 +  def getBoolean(i: Int): Boolean =
 +    if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size)
 +
 +  def getShort(i: Int): Short =
 +    if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size)
 +
 +  def getByte(i: Int): Byte =
 +    if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size)
 +
 +  def getFloat(i: Int): Float =
 +    if (i < row1.size) row1.getFloat(i) else
row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala new file mode 100644 index 000000000..a5d0ecf96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types.DoubleType + +case object Rand extends LeafExpression { + def dataType = DoubleType + def nullable = false + def references = Set.empty + override def toString = "RAND()" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala new file mode 100644 index 000000000..352967546 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +/** + * Represents one row of output from a relational operator. Allows both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * + * It is invalid to use the native primitive interface to retrieve a value that is null, instead a + * user must check [[isNullAt]] before attempting to retrieve a value that might be null. 
+ */ +trait Row extends Seq[Any] with Serializable { + def apply(i: Int): Any + + def isNullAt(i: Int): Boolean + + def getInt(i: Int): Int + def getLong(i: Int): Long + def getDouble(i: Int): Double + def getFloat(i: Int): Float + def getBoolean(i: Int): Boolean + def getShort(i: Int): Short + def getByte(i: Int): Byte + def getString(i: Int): String + + override def toString() = + s"[${this.mkString(",")}]" + + def copy(): Row +} + +/** + * An extended interface to [[Row]] that allows the values for each column to be updated. Setting + * a value through a primitive function implicitly marks that column as not null. + */ +trait MutableRow extends Row { + def setNullAt(i: Int): Unit + + def update(ordinal: Int, value: Any) + + def setInt(ordinal: Int, value: Int) + def setLong(ordinal: Int, value: Long) + def setDouble(ordinal: Int, value: Double) + def setBoolean(ordinal: Int, value: Boolean) + def setShort(ordinal: Int, value: Short) + def setByte(ordinal: Int, value: Byte) + def setFloat(ordinal: Int, value: Float) + def setString(ordinal: Int, value: String) + + /** + * EXPERIMENTAL + * + * Returns a mutable string builder for the specified column. A given row should return the + * result of any mutations made to the returned buffer next time getString is called for the same + * column. + */ + def getStringBuilder(ordinal: Int): StringBuilder +} + +/** + * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. 
+ */ +object EmptyRow extends Row { + def apply(i: Int): Any = throw new UnsupportedOperationException + + def iterator = Iterator.empty + def length = 0 + def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException + + def getInt(i: Int): Int = throw new UnsupportedOperationException + def getLong(i: Int): Long = throw new UnsupportedOperationException + def getDouble(i: Int): Double = throw new UnsupportedOperationException + def getFloat(i: Int): Float = throw new UnsupportedOperationException + def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException + def getShort(i: Int): Short = throw new UnsupportedOperationException + def getByte(i: Int): Byte = throw new UnsupportedOperationException + def getString(i: Int): String = throw new UnsupportedOperationException + + def copy() = this +} + +/** + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied, and thus could technically be mutated after creation, this is not + * allowed. + */ +class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { + /** No-arg constructor for serialization. 
*/ + def this() = this(null) + + def this(size: Int) = this(new Array[Any](size)) + + def iterator = values.iterator + + def length = values.length + + def apply(i: Int) = values(i) + + def isNullAt(i: Int) = values(i) == null + + def getInt(i: Int): Int = { + if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") + values(i).asInstanceOf[Int] + } + + def getLong(i: Int): Long = { + if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") + values(i).asInstanceOf[Long] + } + + def getDouble(i: Int): Double = { + if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") + values(i).asInstanceOf[Double] + } + + def getFloat(i: Int): Float = { + if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") + values(i).asInstanceOf[Float] + } + + def getBoolean(i: Int): Boolean = { + if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") + values(i).asInstanceOf[Boolean] + } + + def getShort(i: Int): Short = { + if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") + values(i).asInstanceOf[Short] + } + + def getByte(i: Int): Byte = { + if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") + values(i).asInstanceOf[Byte] + } + + def getString(i: Int): String = { + if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") + values(i).asInstanceOf[String] + } + + def copy() = this +} + +class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { + /** No-arg constructor for serialization. */ + def this() = this(0) + + def getStringBuilder(ordinal: Int): StringBuilder = ??? 
+ + override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value } + override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value } + + override def setNullAt(i: Int): Unit = { values(i) = null } + + override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value } + + override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value } + + override def copy() = new GenericRow(values.clone()) +} + + +class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { + def compare(a: Row, b: Row): Int = { + var i = 0 + while (i < ordering.size) { + val order = ordering(i) + val left = order.child.apply(a) + val right = order.child.apply(b) + + if (left == null && right == null) { + // Both null, continue looking. 
+ } else if (left == null) { + return if (order.direction == Ascending) -1 else 1 + } else if (right == null) { + return if (order.direction == Ascending) 1 else -1 + } else { + val comparison = order.dataType match { + case n: NativeType if order.direction == Ascending => + n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) + case n: NativeType if order.direction == Descending => + n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) + } + if (comparison != 0) return comparison + } + i += 1 + } + return 0 + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala new file mode 100644 index 000000000..a3c7ca1ac --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +import types._ + +case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) + extends Expression { + + type EvaluatedType = Any + + def references = children.flatMap(_.references).toSet + def nullable = true + + override def apply(input: Row): Any = { + children.size match { + case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input)) + case 2 => + function.asInstanceOf[(Any, Any) => Any]( + children(0).apply(input), + children(1).apply(input)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala new file mode 100644 index 000000000..171997b90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +abstract sealed class SortDirection +case object Ascending extends SortDirection +case object Descending extends SortDirection + +/** + * An expression that can be used to sort a tuple. This class extends expression primarily so that + * transformations over expression will descend into its child. + */ +case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { + def dataType = child.dataType + def nullable = child.nullable + override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala new file mode 100644 index 000000000..2ad8d6f31 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +import scala.language.dynamics + +import types._ + +case object DynamicType extends DataType + +case class WrapDynamic(children: Seq[Attribute]) extends Expression { + type EvaluatedType = DynamicRow + + def nullable = false + def references = children.toSet + def dataType = DynamicType + + override def apply(input: Row): DynamicRow = input match { + // Avoid copy for generic rows. + case g: GenericRow => new DynamicRow(children, g.values) + case otherRowType => new DynamicRow(children, otherRowType.toArray) + } +} + +class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) + extends GenericRow(values) with Dynamic { + + def selectDynamic(attributeName: String): String = { + val ordinal = schema.indexWhere(_.name == attributeName) + values(ordinal).toString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala new file mode 100644 index 000000000..2287a849e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +
 +package org.apache.spark.sql
 +package catalyst
 +package expressions
 +
 +import catalyst.types._
 +
 +abstract class AggregateExpression extends Expression {
 +  self: Product =>
 +
 +  /**
 +   * Creates a new instance that can be used to compute this aggregate expression for a group
 +   * of input rows.
 +   */
 +  def newInstance: AggregateFunction
 +}
 +
 +/**
 + * Represents an aggregation that has been rewritten to be performed in two steps.
 + *
 + * @param finalEvaluation an aggregate expression that evaluates to the same final result as the
 + *                        original aggregation.
 + * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
 + *                           data sets and are required to compute the `finalEvaluation`.
 + */
 +case class SplitEvaluation(
 +    finalEvaluation: Expression,
 +    partialEvaluations: Seq[NamedExpression])
 +
 +/**
 + * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
 + * These partial evaluations can then be combined to compute the actual answer.
 + */
 +abstract class PartialAggregate extends AggregateExpression {
 +  self: Product =>
 +
 +  /**
 +   * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
 +   */
 +  def asPartial: SplitEvaluation
 +}
 +
 +/**
 + * A specific implementation of an aggregate function. Used to wrap a generic
 + * [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
/**
 * Base class for the mutable, per-group evaluation state of an [[AggregateExpression]].
 */
abstract class AggregateFunction
  extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
  self: Product =>

  type EvaluatedType = Any

  /** Base should return the generic aggregate expression that this function is computing */
  val base: AggregateExpression
  def references = base.references
  def nullable = base.nullable
  def dataType = base.dataType

  /** Folds one input row into the running aggregation state. */
  def update(input: Row): Unit

  /** Returns the current result of the aggregation. */
  override def apply(input: Row): Any

  // Do we really need this?
  def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}

case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  def references = child.references
  def nullable = false
  def dataType = IntegerType
  override def toString = s"COUNT($child)"

  // Partial counts are combined by summing them on the merge side.
  def asPartial: SplitEvaluation = {
    val partialCount = Alias(Count(child), "PartialCount")()
    SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
  }

  override def newInstance = new CountFunction(child, this)
}

case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
  def children = expressions
  def references = expressions.flatMap(_.references).toSet
  def nullable = false
  def dataType = IntegerType
  // Fixed: the original had a stray extra '}' producing "COUNT(DISTINCT a,b})".
  override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})"
  override def newInstance = new CountDistinctFunction(expressions, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  def references = child.references
  def nullable = false
  def dataType = DoubleType
  override def toString = s"AVG($child)"

  // AVG splits into SUM and COUNT partials; the final value is sum-of-sums / sum-of-counts.
  override def asPartial: SplitEvaluation = {
    val partialSum = Alias(Sum(child), "PartialSum")()
    val partialCount = Alias(Count(child), "PartialCount")()
    val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
    val castedCount = Cast(Sum(partialCount.toAttribute), dataType)

    SplitEvaluation(
      Divide(castedSum, castedCount),
      partialCount :: partialSum :: Nil)
  }

  override def newInstance = new AverageFunction(child, this)
}

case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  def references = child.references
  def nullable = false
  def dataType = child.dataType
  override def toString = s"SUM($child)"

  override def asPartial: SplitEvaluation = {
    val partialSum = Alias(Sum(child), "PartialSum")()
    SplitEvaluation(
      Sum(partialSum.toAttribute),
      partialSum :: Nil)
  }

  override def newInstance = new SumFunction(child, this)
}

case class SumDistinct(child: Expression)
  extends AggregateExpression with trees.UnaryNode[Expression] {

  def references = child.references
  def nullable = false
  def dataType = child.dataType
  override def toString = s"SUM(DISTINCT $child)"

  override def newInstance = new SumDistinctFunction(child, this)
}

case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
  def references = child.references
  def nullable = child.nullable
  def dataType = child.dataType
  override def toString = s"FIRST($child)"

  // FIRST of the partial FIRSTs is a valid FIRST of the whole.
  override def asPartial: SplitEvaluation = {
    val partialFirst = Alias(First(child), "PartialFirst")()
    SplitEvaluation(
      First(partialFirst.toAttribute),
      partialFirst :: Nil)
  }
  override def newInstance = new FirstFunction(child, this)
}

case class AverageFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  private var count: Long = _
  // Running sum held in the child's data type, seeded with a zero cast to that type.
  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
  private val sumAsDouble = Cast(sum, DoubleType)

  private val addFunction = Add(sum, expr)

  // NOTE(review): when no rows were seen this evaluates 0.0 / 0.0 == NaN — confirm intended.
  override def apply(input: Row): Any =
    sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble

  def update(input: Row): Unit = {
    count += 1
    sum.update(addFunction, input)
  }
}

case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  var count: Int = _

  def update(input: Row): Unit = {
    // NOTE(review): TreeNode.map evaluates every node in expr's tree, not a sequence of
    // expressions — presumably the intent is "count the row if anything is non-null"; confirm.
    val evaluatedExpr = expr.map(_.apply(input))
    if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
      count += 1
    }
  }

  override def apply(input: Row): Any = count
}

case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))

  private val addFunction = Add(sum, expr)

  def update(input: Row): Unit = {
    sum.update(addFunction, input)
  }

  override def apply(input: Row): Any = sum.apply(null)
}

case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  // Distinct non-null values seen so far; summed on demand in apply().
  val seen = new scala.collection.mutable.HashSet[Any]()

  def update(input: Row): Unit = {
    val evaluatedExpr = expr.apply(input)
    if (evaluatedExpr != null) {
      seen += evaluatedExpr
    }
  }

  override def apply(input: Row): Any =
    seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}

case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
  extends AggregateFunction {

  def this() = this(null, null) // Required for serialization.

  // Distinct tuples (as Seq values) where every component was non-null.
  val seen = new scala.collection.mutable.HashSet[Any]()

  def update(input: Row): Unit = {
    val evaluatedExpr = expr.map(_.apply(input))
    if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
      seen += evaluatedExpr
    }
  }

  override def apply(input: Row): Any = seen.size
}

case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
  def this() = this(null, null) // Required for serialization.

  var result: Any = null

  // Keeps the first non-null value seen; later rows are ignored.
  def update(input: Row): Unit = {
    if (result == null) {
      result = expr.apply(input)
    }
  }

  override def apply(input: Row): Any = result
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import catalyst.analysis.UnresolvedException
import catalyst.types._

/** Arithmetic negation of a numeric child expression. */
case class UnaryMinus(child: Expression) extends UnaryExpression {
  type EvaluatedType = Any

  def dataType = child.dataType
  override def foldable = child.foldable
  def nullable = child.nullable
  override def toString = s"-$child"

  override def apply(input: Row): Any = n1(child, input, _.negate(_))
}

/**
 * Common scaffolding for binary arithmetic: both sides must resolve to the same
 * data type, which becomes the result type.
 */
abstract class BinaryArithmetic extends BinaryExpression {
  self: Product =>

  type EvaluatedType = Any

  def nullable = left.nullable || right.nullable

  override lazy val resolved =
    left.resolved && right.resolved && left.dataType == right.dataType

  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(this,
        s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
    }
    left.dataType
  }
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "+"

  override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _))
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "-"

  override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _))
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "*"

  override def apply(input: Row): Any = n2(input, left, right, _.times(_, _))
}

case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "/"

  // Fractional types use true division; integral types use truncated (quot) division.
  override def apply(input: Row): Any = dataType match {
    case _: FractionalType => f2(input, left, right, _.div(_, _))
    case _: IntegralType => i2(input, left, right, _.quot(_, _))
  }
}

case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
  def symbol = "%"

  override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import types._

/**
 * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
 */
case class GetItem(child: Expression, ordinal: Expression) extends Expression {
  type EvaluatedType = Any

  val children = child :: ordinal :: Nil

  /** `Null` is returned for invalid ordinals. */
  override def nullable = true
  override def references = children.flatMap(_.references).toSet

  // Result type is the array element type or the map value type.
  def dataType = child.dataType match {
    case ArrayType(dt) => dt
    case MapType(_, vt) => vt
  }

  // Only resolved once the child is known to be an array or a map.
  override lazy val resolved =
    childrenResolved &&
    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

  override def toString = s"$child[$ordinal]"

  override def apply(input: Row): Any = {
    if (child.dataType.isInstanceOf[ArrayType]) {
      val baseValue = child.apply(input).asInstanceOf[Seq[_]]
      val o = ordinal.apply(input).asInstanceOf[Int]
      if (baseValue == null) {
        null
      } else if (o >= baseValue.size || o < 0) {
        // Out-of-range ordinals yield null rather than throwing.
        null
      } else {
        baseValue(o)
      }
    } else {
      val baseValue = child.apply(input).asInstanceOf[Map[Any, _]]
      val key = ordinal.apply(input)
      if (baseValue == null) {
        null
      } else {
        // Missing keys yield null.
        baseValue.get(key).orNull
      }
    }
  }
}

/**
 * Returns the value of fields in the Struct `child`.
 */
case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
  type EvaluatedType = Any

  def dataType = field.dataType
  def nullable = field.nullable

  protected def structType = child.dataType match {
    case s: StructType => s
    case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
  }

  // Resolved lazily so that it is only looked up once the child's type is known.
  lazy val field =
    structType.fields
      .find(_.name == fieldName)
      .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))

  lazy val ordinal = structType.fields.indexOf(field)

  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]

  override def apply(input: Row): Any = {
    val baseValue = child.apply(input).asInstanceOf[Row]
    if (baseValue == null) null else baseValue(ordinal)
  }

  override def toString = s"$child.$fieldName"
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import catalyst.types._

/**
 * An expression that produces zero or more rows given a single input row.
 *
 * Generators produce multiple output rows instead of a single value like other expressions,
 * and thus they must have a schema to associate with the rows that are output.
 *
 * However, unlike row producing relational operators, which are either leaves or determine their
 * output schema functionally from their input, generators can contain other expressions that
 * might result in their modification by rules.  This structure means that they might be copied
 * multiple times after first determining their output schema.  If a new output schema were created
 * for each copy, references up the tree might be rendered invalid.  As a result, generators must
 * instead define a function `makeOutput` which is called only once when the schema is first
 * requested.  The attributes produced by this function will be automatically copied any time rules
 * result in changes to the Generator or its children.
 */
abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
  self: Product =>

  type EvaluatedType = TraversableOnce[Row]

  lazy val dataType =
    ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))

  def nullable = false

  def references = children.flatMap(_.references).toSet

  /**
   * Should be overridden by specific generators.  Called only once for each instance to ensure
   * that rule application does not change the output schema of a generator.
   */
  protected def makeOutput(): Seq[Attribute]

  // Cached schema; deliberately mutable so makeCopy can carry it across tree transformations.
  private var _output: Seq[Attribute] = null

  def output: Seq[Attribute] = {
    if (_output == null) {
      _output = makeOutput()
    }
    _output
  }

  /** Should be implemented by child classes to perform specific Generators. */
  def apply(input: Row): TraversableOnce[Row]

  /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
  override def makeCopy(newArgs: Array[AnyRef]): this.type = {
    val copy = super.makeCopy(newArgs)
    copy._output = _output
    copy
  }
}

/**
 * Given an input array produces a sequence of rows for each value in the array.
 */
case class Explode(attributeNames: Seq[String], child: Expression)
  extends Generator with trees.UnaryNode[Expression] {

  override lazy val resolved =
    child.resolved &&
    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

  // One output column per array, two (key, value) per map.
  lazy val elementTypes = child.dataType match {
    case ArrayType(et) => et :: Nil
    case MapType(kt, vt) => kt :: vt :: Nil
  }

  // TODO: Move this pattern into Generator.
  protected def makeOutput() =
    if (attributeNames.size == elementTypes.size) {
      attributeNames.zip(elementTypes).map {
        case (n, t) => AttributeReference(n, t, nullable = true)()
      }
    } else {
      // Fall back to synthetic names c_0, c_1, ... when names were not supplied.
      elementTypes.zipWithIndex.map {
        case (t, i) => AttributeReference(s"c_$i", t, nullable = true)()
      }
    }

  override def apply(input: Row): TraversableOnce[Row] = {
    child.dataType match {
      case ArrayType(_) =>
        val inputArray = child.apply(input).asInstanceOf[Seq[Any]]
        if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
      case MapType(_, _) =>
        val inputMap = child.apply(input).asInstanceOf[Map[Any, Any]]
        if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) }
    }
  }

  override def toString() = s"explode($child)"
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import types._

object Literal {
  /** Infers the Catalyst data type from a raw Scala value and wraps it as a Literal. */
  def apply(v: Any): Literal = v match {
    case i: Int => Literal(i, IntegerType)
    case l: Long => Literal(l, LongType)
    case d: Double => Literal(d, DoubleType)
    case f: Float => Literal(f, FloatType)
    case b: Byte => Literal(b, ByteType)
    case s: Short => Literal(s, ShortType)
    case s: String => Literal(s, StringType)
    case b: Boolean => Literal(b, BooleanType)
    case null => Literal(null, NullType)
  }
}

/**
 * Extractor for retrieving Int literals.
 */
object IntegerLiteral {
  def unapply(a: Any): Option[Int] = a match {
    case Literal(a: Int, IntegerType) => Some(a)
    case _ => None
  }
}

/** A constant value with an explicit data type. */
case class Literal(value: Any, dataType: DataType) extends LeafExpression {

  override def foldable = true
  // A literal is nullable exactly when its value is null.
  def nullable = value == null
  def references = Set.empty

  override def toString = if (value != null) value.toString else "null"

  type EvaluatedType = Any
  override def apply(input: Row): Any = value
}

// TODO: Specialize
/** A literal whose value can be updated in place (used for aggregate accumulators). */
case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression {
  type EvaluatedType = Any

  // Data type is inferred once from the initial value.
  val dataType = Literal(value).dataType

  def references = Set.empty

  /** Replaces the stored value with `expression` evaluated against `input`. */
  def update(expression: Expression, input: Row) = {
    value = expression.apply(input)
  }

  override def apply(input: Row) = value
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import catalyst.analysis.UnresolvedAttribute
import types._

object NamedExpression {
  private val curId = new java.util.concurrent.atomic.AtomicLong()
  /** Returns a fresh, JVM-unique expression id. */
  def newExprId = ExprId(curId.getAndIncrement())
}

/**
 * A globally (within this JVM) unique id for a given named expression.
 * Used to identify which attribute output by a relation is being
 * referenced in a subsequent computation.
 */
case class ExprId(id: Long)

abstract class NamedExpression extends Expression {
  self: Product =>

  def name: String
  def exprId: ExprId
  def qualifiers: Seq[String]

  def toAttribute: Attribute

  // Suffix appended to debug strings so Long-typed attributes read like Scala literals.
  protected def typeSuffix =
    if (resolved) {
      dataType match {
        case LongType => "L"
        case _ => ""
      }
    } else {
      ""
    }
}

abstract class Attribute extends NamedExpression {
  self: Product =>

  def withQualifiers(newQualifiers: Seq[String]): Attribute

  def references = Set(this)
  def toAttribute = this
  def newInstance: Attribute
}

/**
 * Used to assign a new name to a computation.
 * For example the SQL expression "1 + 1 AS a" could be represented as follows:
 *   Alias(Add(Literal(1), Literal(1)), "a")()
 *
 * @param child the computation being performed
 * @param name the name to be associated with the result of computing [[child]].
 * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
 *               alias. Auto-assigned if left blank.
 */
case class Alias(child: Expression, name: String)
    (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
  extends NamedExpression with trees.UnaryNode[Expression] {

  type EvaluatedType = Any

  override def apply(input: Row) = child.apply(input)

  def dataType = child.dataType
  def nullable = child.nullable
  def references = child.references

  def toAttribute = {
    if (resolved) {
      AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers)
    } else {
      UnresolvedAttribute(name)
    }
  }

  override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix"

  // Secondary-constructor args must be carried along when the tree is copied.
  override protected final def otherCopyArgs = exprId :: qualifiers :: Nil
}

/**
 * A reference to an attribute produced by another operator in the tree.
 *
 * @param name The name of this attribute, should only be used during analysis or for debugging.
 * @param dataType The [[types.DataType DataType]] of this attribute.
 * @param nullable True if null is a valid value for this attribute.
 * @param exprId A globally unique id used to check if different AttributeReferences refer to the
 *               same attribute.
 * @param qualifiers a list of strings that can be used to refer to this attribute in a fully
 *                   qualified way. Consider the examples tableName.name, subQueryAlias.name.
 *                   tableName and subQueryAlias are possible qualifiers.
 */
case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true)
    (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil)
  extends Attribute with trees.LeafNode[Expression] {

  // Equality is intentionally based on exprId and dataType only; the name is just a debug label.
  override def equals(other: Any) = other match {
    case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType
    case _ => false
  }

  override def hashCode: Int = {
    // See http://stackoverflow.com/questions/113511/hash-code-implementation
    var h = 17
    h = h * 37 + exprId.hashCode()
    h = h * 37 + dataType.hashCode()
    h
  }

  // A copy with a freshly assigned exprId.
  def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers)

  /**
   * Returns a copy of this [[AttributeReference]] with changed nullability.
   */
  def withNullability(newNullability: Boolean) = {
    if (nullable == newNullability) {
      this
    } else {
      AttributeReference(name, dataType, newNullability)(exprId, qualifiers)
    }
  }

  /**
   * Returns a copy of this [[AttributeReference]] with new qualifiers.
   */
  def withQualifiers(newQualifiers: Seq[String]) = {
    if (newQualifiers == qualifiers) {
      this
    } else {
      AttributeReference(name, dataType, nullable)(exprId, newQualifiers)
    }
  }

  override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import catalyst.analysis.UnresolvedException

/** Returns the first non-null argument, evaluating children left to right. */
case class Coalesce(children: Seq[Expression]) extends Expression {
  type EvaluatedType = Any

  /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
  def nullable = !children.exists(!_.nullable)

  def references = children.flatMap(_.references).toSet
  // Coalesce is foldable if all children are foldable.
  override def foldable = !children.exists(!_.foldable)

  // Only resolved if all the children are of the same type.
  override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)

  override def toString = s"Coalesce(${children.mkString(",")})"

  def dataType = if (resolved) {
    children.head.dataType
  } else {
    throw new UnresolvedException(this, "Coalesce cannot have children of different types.")
  }

  override def apply(input: Row): Any = {
    var i = 0
    var result: Any = null
    // Stop at the first non-null value; remaining children are never evaluated.
    while (i < children.size && result == null) {
      result = children(i).apply(input)
      i += 1
    }
    result
  }
}

case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
  def references = child.references
  override def foldable = child.foldable
  def nullable = false
  // Added for consistency: the sibling IsNotNull defines toString but IsNull did not.
  override def toString = s"IS NULL $child"

  override def apply(input: Row): Any = {
    child.apply(input) == null
  }
}

case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
  def references = child.references
  override def foldable = child.foldable
  def nullable = false
  override def toString = s"IS NOT NULL $child"

  override def apply(input: Row): Any = {
    child.apply(input) != null
  }
}
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst

/**
 * A set of classes that can be used to represent trees of relational expressions.  A key goal of
 * the expression library is to hide the details of naming and scoping from developers who want to
 * manipulate trees of relational operators. As such, the library defines a special type of
 * expression, a [[NamedExpression]], in addition to the standard collection of expressions.
 *
 * ==Standard Expressions==
 * A library of standard expressions (e.g., [[Add]], [[Equals]]), aggregates (e.g., SUM, COUNT),
 * and other computations (e.g. UDFs). Each expression type is capable of determining its output
 * schema as a function of its children's output schema.
 *
 * ==Named Expressions==
 * Some expressions are named and thus can be referenced by later operators in the dataflow graph.
 * The two types of named expressions are [[AttributeReference]]s and [[Alias]]es.
 * [[AttributeReference]]s refer to attributes of the input tuple for a given operator and form
 * the leaves of some expression trees.  Aliases assign a name to intermediate computations.
 * For example, in the SQL statement `SELECT a+b AS c FROM ...`, the expressions `a` and `b` would
 * be represented by `AttributeReferences` and `c` would be represented by an `Alias`.
 *
 * During [[analysis]], all named expressions are assigned a globally unique expression id, which
 * can be used for equality comparisons.  While the original names are kept around for debugging
 * purposes, they should never be used to check if two attributes refer to the same value, as
 * plan transformations can result in the introduction of naming ambiguity.  For example, consider
 * a plan that contains subqueries, both of which are reading from the same table.  If an
 * optimization removes the subqueries, scoping information would be destroyed, eliminating the
 * ability to reason about which subquery produced a given attribute.
 *
 * ==Evaluation==
 * The result of expressions can be evaluated using the [[Evaluate]] object.
 */
package object expressions
// ===== new file: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala =====
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql
package catalyst
package expressions

import types._
import catalyst.analysis.UnresolvedException

/** An expression that evaluates to a boolean (or null, under SQL three-valued logic). */
trait Predicate extends Expression {
  self: Product =>

  def dataType = BooleanType

  type EvaluatedType = Any
}

trait PredicateHelper {
  /** Flattens nested ANDs into a sequence of conjuncts. */
  def splitConjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
    case And(cond1, cond2) => splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
    case other => other :: Nil
  }
}

abstract class BinaryPredicate extends BinaryExpression with Predicate {
  self: Product =>
  def nullable = left.nullable || right.nullable
}

case class Not(child: Expression) extends Predicate with trees.UnaryNode[Expression] {
  def references = child.references
  override def foldable = child.foldable
  def nullable = child.nullable
  override def toString = s"NOT $child"

  // NOT null is null; otherwise plain boolean negation.
  override def apply(input: Row): Any = {
    child.apply(input) match {
      case null => null
      case b: Boolean => !b
    }
  }
}

/**
 * Evaluates to `true` if `list` contains `value`.
 */
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
  def children = value +: list
  def references = children.flatMap(_.references).toSet
  def nullable = true // TODO: Figure out correct nullability semantics of IN.
  override def toString = s"$value IN ${list.mkString("(", ",", ")")}"

  override def apply(input: Row): Any = {
    val evaluatedValue = value.apply(input)
    list.exists(e => e.apply(input) == evaluatedValue)
  }
}

case class And(left: Expression, right: Expression) extends BinaryPredicate {
  def symbol = "&&"

  // SQL three-valued logic: false dominates null; null AND true is null.
  override def apply(input: Row): Any = {
    val l = left.apply(input)
    val r = right.apply(input)
    if (l == false || r == false) {
      false
    } else if (l == null || r == null) {
      null
    } else {
      true
    }
  }
}

case class Or(left: Expression, right: Expression) extends BinaryPredicate {
  def symbol = "||"

  // SQL three-valued logic: true dominates null; null OR false is null.
  override def apply(input: Row): Any = {
    val l = left.apply(input)
    val r = right.apply(input)
    if (l == true || r == true) {
      true
    } else if (l == null || r == null) {
      null
    } else {
      false
    }
  }
}

abstract class BinaryComparison extends BinaryPredicate {
  self: Product =>
}

case class Equals(left: Expression, right: Expression) extends BinaryComparison {
  def symbol = "="
  override def apply(input: Row): Any = {
    val l = left.apply(input)
    val r = right.apply(input)
    // Any null operand makes the comparison null.
    if (l == null || r == null) null else l == r
  }
}

case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
  def symbol = "<"
  override def apply(input: Row): Any = {
    // Strings compare lexicographically; other types use their Numeric ordering via n2.
    if (left.dataType == StringType && right.dataType == StringType) {
      val l = left.apply(input)
      val r = right.apply(input)
      if (l == null || r == null) {
        null
      } else {
        l.asInstanceOf[String] < r.asInstanceOf[String]
      }
    } else {
      n2(input, left, right, _.lt(_, _))
    }
  }
}

case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
  def symbol = "<="
  override def apply(input: Row): Any = {
    if (left.dataType == StringType && right.dataType == StringType) {
      val l = left.apply(input)
      val r = right.apply(input)
      if (l == null || r == null) {
        null
      } else {
        l.asInstanceOf[String] <= r.asInstanceOf[String]
      }
    } else {
      n2(input, left, right, _.lteq(_, _))
    }
  }
}

case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
  def symbol = ">"
  override def apply(input: Row): Any = {
    if (left.dataType == StringType && right.dataType == StringType) {
      val l = left.apply(input)
      val r = right.apply(input)
      if (l == null || r == null) {
        null
      } else {
        l.asInstanceOf[String] > r.asInstanceOf[String]
      }
    } else {
      n2(input, left, right, _.gt(_, _))
    }
  }
}

case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
  def symbol = ">="
  override def apply(input: Row): Any = {
    if (left.dataType == StringType && right.dataType == StringType) {
      val l = left.apply(input)
      val r = right.apply(input)
      if (l == null || r == null) {
        null
      } else {
        l.asInstanceOf[String] >= r.asInstanceOf[String]
      }
    } else {
      n2(input, left, right, _.gteq(_, _))
    }
  }
}

case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
  extends Expression {

  def children = predicate :: trueValue :: falseValue :: Nil
  def nullable = trueValue.nullable || falseValue.nullable
  def references = children.flatMap(_.references).toSet
  // Both branches must agree on a data type before this expression is resolved.
  override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType
  def dataType = {
    if (!resolved) {
      throw new UnresolvedException(
        this,
        s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}")
    }
    trueValue.dataType
  }

  type EvaluatedType = Any
  override def apply(input: Row): Any = {
    // Only the selected branch is evaluated.
    if (predicate(input).asInstanceOf[Boolean]) {
      trueValue.apply(input)
    } else {
      falseValue.apply(input)
    }
  }

  override def toString = s"if ($predicate) $trueValue else $falseValue"
}
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala new file mode 100644 index 000000000..6e585236b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package expressions + +import catalyst.types.BooleanType + +case class Like(left: Expression, right: Expression) extends BinaryExpression { + def dataType = BooleanType + def nullable = left.nullable // Right cannot be null. + def symbol = "LIKE" +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala new file mode 100644 index 000000000..4db280317 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package optimizer + +import catalyst.expressions._ +import catalyst.plans.logical._ +import catalyst.rules._ +import catalyst.types.BooleanType +import catalyst.plans.Inner + +object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueries) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification, + SimplifyCasts) :: + Batch("Filter Pushdown", Once, + EliminateSubqueries, + CombineFilters, + PushPredicateThroughProject, + PushPredicateThroughInnerJoin) :: Nil +} + +/** + * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are + * only required to provide scoping information for attributes and can be removed once analysis is + * complete. + */ +object EliminateSubqueries extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Subquery(_, child) => child + } +} + +/** + * Replaces [[catalyst.expressions.Expression Expressions]] that can be statically evaluated with + * equivalent [[catalyst.expressions.Literal Literal]] values. + */ +object ConstantFolding extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + // Skip redundant folding of literals. 
+ case l: Literal => l + case e if e.foldable => Literal(e.apply(null), e.dataType) + } + } +} + +/** + * Simplifies boolean expressions where the answer can be determined without evaluating both sides. + * Note that this rule can eliminate expressions that might otherwise have been evaluated and thus + * is only safe when evaluations of expressions does not result in side effects. + */ +object BooleanSimplification extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case and @ And(left, right) => { + (left, right) match { + case (Literal(true, BooleanType), r) => r + case (l, Literal(true, BooleanType)) => l + case (Literal(false, BooleanType), _) => Literal(false) + case (_, Literal(false, BooleanType)) => Literal(false) + case (_, _) => and + } + } + case or @ Or(left, right) => { + (left, right) match { + case (Literal(true, BooleanType), _) => Literal(true) + case (_, Literal(true, BooleanType)) => Literal(true) + case (Literal(false, BooleanType), r) => r + case (l, Literal(false, BooleanType)) => l + case (_, _) => or + } + } + } + } +} + +/** + * Combines two adjacent [[catalyst.plans.logical.Filter Filter]] operators into one, merging the + * conditions into one conjunctive predicate. + */ +object CombineFilters extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ff@Filter(fc, nf@Filter(nc, grandChild)) => Filter(And(nc, fc), grandChild) + } +} + +/** + * Pushes [[catalyst.plans.logical.Filter Filter]] operators through + * [[catalyst.plans.logical.Project Project]] operators, in-lining any + * [[catalyst.expressions.Alias Aliases]] that were defined in the projection. + * + * This heuristic is valid assuming the expression evaluation cost is minimal. 
+ */ +object PushPredicateThroughProject extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case filter@Filter(condition, project@Project(fields, grandChild)) => + val sourceAliases = fields.collect { case a@Alias(c, _) => a.toAttribute -> c }.toMap + project.copy(child = filter.copy( + replaceAlias(condition, sourceAliases), + grandChild)) + } + + // + def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + condition transform { + case a: AttributeReference => sourceAliases.getOrElse(a, a) + } + } +} + +/** + * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be + * evaluated using only the attributes of the left or right side of an inner join. Other + * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the + * [[catalyst.plans.logical.Join Join]]. + */ +object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) => + val allConditions = + splitConjunctivePredicates(filterCondition) ++ + joinCondition.map(splitConjunctivePredicates).getOrElse(Nil) + + // Split the predicates into those that can be evaluated on the left, right, and those that + // must be evaluated after the join. + val (rightConditions, leftOrJoinConditions) = + allConditions.partition(_.references subsetOf right.outputSet) + val (leftConditions, joinConditions) = + leftOrJoinConditions.partition(_.references subsetOf left.outputSet) + + // Build the new left and right side, optionally with the pushed down filters. 
+ val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And)) + } +} + +/** + * Removes [[catalyst.expressions.Cast Casts]] that are unnecessary because the input is already + * the correct type. + */ +object SimplifyCasts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case Cast(e, dataType) if e.dataType == dataType => e + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala new file mode 100644 index 000000000..22f8ea005 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package planning + + +import plans.logical.LogicalPlan +import trees._ + +/** + * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. 
+ * Child classes are responsible for specifying a list of [[Strategy]] objects that each of which + * can return a list of possible physical plan options. If a given strategy is unable to plan all + * of the remaining operators in the tree, it can call [[planLater]], which returns a placeholder + * object that will be filled in using other available strategies. + * + * TODO: RIGHT NOW ONLY ONE PLAN IS RETURNED EVER... + * PLAN SPACE EXPLORATION WILL BE IMPLEMENTED LATER. + * + * @tparam PhysicalPlan The type of physical plan produced by this [[QueryPlanner]] + */ +abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { + /** A list of execution strategies that can be used by the planner */ + def strategies: Seq[Strategy] + + /** + * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * be used for execution. If this strategy does not apply to the give logical operation then an + * empty list should be returned. + */ + abstract protected class Strategy extends Logging { + def apply(plan: LogicalPlan): Seq[PhysicalPlan] + } + + /** + * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be + * filled in automatically by the QueryPlanner using the other execution strategies that are + * available. + */ + protected def planLater(plan: LogicalPlan) = apply(plan).next() + + def apply(plan: LogicalPlan): Iterator[PhysicalPlan] = { + // Obviously a lot to do here still... 
+ val iter = strategies.view.flatMap(_(plan)).toIterator + assert(iter.hasNext, s"No plan for $plan") + iter + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala new file mode 100644 index 000000000..64370ec7c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * Contains classes for enumerating possible physical plans for a given logical query plan. + */ +package object planning diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala new file mode 100644 index 000000000..613b028ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package planning + +import scala.annotation.tailrec + +import expressions._ +import plans.logical._ + +/** + * A pattern that matches any number of filter operations on top of another relational operator. + * Adjacent filter operators are collected and their conditions are broken up and returned as a + * sequence of conjunctive predicates. + * + * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the + * output and a relational operator. + */ +object FilteredOperation extends PredicateHelper { + type ReturnType = (Seq[Expression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan)) + + @tailrec + private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match { + case Filter(condition, child) => + collectFilters(filters ++ splitConjunctivePredicates(condition), child) + case other => (filters, other) + } +} + +/** + * A pattern that matches any number of project or filter operations on top of another relational + * operator. All filter operators are collected and their conditions are broken up and returned + * together with the top project operator. 
[[Alias Aliases]] are in-lined/substituted if necessary. + */ +object PhysicalOperation extends PredicateHelper { + type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) + + def unapply(plan: LogicalPlan): Option[ReturnType] = { + val (fields, filters, child, _) = collectProjectsAndFilters(plan) + Some((fields.getOrElse(child.output), filters, child)) + } + + /** + * Collects projects and filters, in-lining/substituting aliases if necessary. Here are two + * examples for alias in-lining/substitution. Before: + * {{{ + * SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 + * SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10 + * }}} + * After: + * {{{ + * SELECT key AS c1 FROM t1 WHERE key > 10 + * SELECT key AS c2 FROM t1 WHERE key > 10 + * }}} + */ + def collectProjectsAndFilters(plan: LogicalPlan): + (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) = + plan match { + case Project(fields, child) => + val (_, filters, other, aliases) = collectProjectsAndFilters(child) + val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] + (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + + case Filter(condition, child) => + val (fields, filters, other, aliases) = collectProjectsAndFilters(child) + val substitutedCondition = substitute(aliases)(condition) + (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + + case other => + (None, Nil, other, Map.empty) + } + + def collectAliases(fields: Seq[Expression]) = fields.collect { + case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + }.toMap + + def substitute(aliases: Map[Attribute, Expression])(expr: Expression) = expr.transform { + case a @ Alias(ref: AttributeReference, name) => + aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a) + + case a: AttributeReference => + aliases.get(a).map(Alias(_, 
a.name)(a.exprId, a.qualifiers)).getOrElse(a) + } +} + +/** + * A pattern that collects all adjacent unions and returns their children as a Seq. + */ +object Unions { + def unapply(plan: LogicalPlan): Option[Seq[LogicalPlan]] = plan match { + case u: Union => Some(collectUnionChildren(u)) + case _ => None + } + + private def collectUnionChildren(plan: LogicalPlan): Seq[LogicalPlan] = plan match { + case Union(l, r) => collectUnionChildren(l) ++ collectUnionChildren(r) + case other => other :: Nil + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala new file mode 100644 index 000000000..20f230c5c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package plans + +import catalyst.expressions.{SortOrder, Attribute, Expression} +import catalyst.trees._ + +abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { + self: PlanType with Product => + + def output: Seq[Attribute] + + /** + * Returns the set of attributes that are output by this node. + */ + def outputSet: Set[Attribute] = output.toSet + + /** + * Runs [[transform]] with `rule` on all expressions present in this query operator. + * Users should not expect a specific directionality. If a specific directionality is needed, + * transformExpressionsDown or transformExpressionsUp should be used. + * @param rule the rule to be applied to every expression in this operator. + */ + def transformExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transformExpressionsDown(rule) + } + + /** + * Runs [[transformDown]] with `rule` on all expressions present in this query operator. + * @param rule the rule to be applied to every expression in this operator. + */ + def transformExpressionsDown(rule: PartialFunction[Expression, Expression]): this.type = { + var changed = false + + @inline def transformExpressionDown(e: Expression) = { + val newE = e.transformDown(rule) + if (newE.id != e.id && newE != e) { + changed = true + newE + } else { + e + } + } + + val newArgs = productIterator.map { + case e: Expression => transformExpressionDown(e) + case Some(e: Expression) => Some(transformExpressionDown(e)) + case m: Map[_,_] => m + case seq: Traversable[_] => seq.map { + case e: Expression => transformExpressionDown(e) + case other => other + } + case other: AnyRef => other + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** + * Runs [[transformUp]] with `rule` on all expressions present in this query operator. + * @param rule the rule to be applied to every expression in this operator. 
+ * @return + */ + def transformExpressionsUp(rule: PartialFunction[Expression, Expression]): this.type = { + var changed = false + + @inline def transformExpressionUp(e: Expression) = { + val newE = e.transformUp(rule) + if (newE.id != e.id && newE != e) { + changed = true + newE + } else { + e + } + } + + val newArgs = productIterator.map { + case e: Expression => transformExpressionUp(e) + case Some(e: Expression) => Some(transformExpressionUp(e)) + case m: Map[_,_] => m + case seq: Traversable[_] => seq.map { + case e: Expression => transformExpressionUp(e) + case other => other + } + case other: AnyRef => other + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** Returns the result of running [[transformExpressions]] on this node + * and all its children. */ + def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { + transform { + case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType] + }.asInstanceOf[this.type] + } + + /** Returns all of the expressions present in this query plan operator. */ + def expressions: Seq[Expression] = { + productIterator.flatMap { + case e: Expression => e :: Nil + case Some(e: Expression) => e :: Nil + case seq: Traversable[_] => seq.flatMap { + case e: Expression => e :: Nil + case other => Nil + } + case other => Nil + }.toSeq + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala new file mode 100644 index 000000000..9f2283ad4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans + +sealed abstract class JoinType +case object Inner extends JoinType +case object LeftOuter extends JoinType +case object RightOuter extends JoinType +case object FullOuter extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala new file mode 100644 index 000000000..48ff45c3d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/BaseRelation.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +abstract class BaseRelation extends LeafNode { + self: Product => + + def tableName: String + def isPartitioned: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala new file mode 100644 index 000000000..bc7b6871d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import catalyst.expressions._ +import catalyst.errors._ +import catalyst.types.StructType + +abstract class LogicalPlan extends QueryPlan[LogicalPlan] { + self: Product => + + /** + * Returns the set of attributes that are referenced by this node + * during evaluation. + */ + def references: Set[Attribute] + + /** + * Returns the set of attributes that this node takes as + * input from its children. 
+ */ + lazy val inputSet: Set[Attribute] = children.flatMap(_.output).toSet + + /** + * Returns true if this expression and all its children have been resolved to a specific schema + * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan + * can override this (e.g. [[catalyst.analysis.UnresolvedRelation UnresolvedRelation]] should + * return `false`). + */ + lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved + + /** + * Returns true if all its children of this query plan have been resolved. + */ + def childrenResolved = !children.exists(!_.resolved) + + /** + * Optionally resolves the given string to a + * [[catalyst.expressions.NamedExpression NamedExpression]]. The attribute is expressed as + * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. + */ + def resolve(name: String): Option[NamedExpression] = { + val parts = name.split("\\.") + // Collect all attributes that are output by this nodes children where either the first part + // matches the name or where the first part matches the scope and the second part matches the + // name. Return these matches along with any remaining parts, which represent dotted access to + // struct fields. + val options = children.flatMap(_.output).flatMap { option => + // If the first part of the desired name matches a qualifier for this possible match, drop it. + val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts + if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil + } + + options.distinct match { + case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it. + // One match, but we also need to extract the requested nested field. 
+ case (a, nestedFields) :: Nil => + a.dataType match { + case StructType(fields) => + Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) + case _ => None // Don't know how to resolve these field references + } + case Nil => None // No matches. + case ambiguousReferences => + throw new TreeNodeException( + this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + } + } +} + +/** + * A logical plan node with no children. + */ +abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { + self: Product => + + // Leaf nodes by definition cannot reference any input attributes. + def references = Set.empty +} + +/** + * A logical node that represents a non-query command to be executed by the system. For example, + * commands can be used by parsers to represent DDL operations. + */ +abstract class Command extends LeafNode { + self: Product => + def output = Seq.empty +} + +/** + * Returned for commands supported by a given parser, but not catalyst. In general these are DDL + * commands that are passed directly to another system. + */ +case class NativeCommand(cmd: String) extends Command + +/** + * Returned by a parser when the users only wants to see what query plan would be executed, without + * actually performing the execution. + */ +case class ExplainCommand(plan: LogicalPlan) extends Command + +/** + * A logical plan node with single child. + */ +abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] { + self: Product => +} + +/** + * A logical plan node with a left and right child. 
+ */ +abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] { + self: Product => +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala new file mode 100644 index 000000000..1a1a2b9b8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import expressions._ + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. 
+ */ +case class ScriptTransformation( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + def references = input.flatMap(_.references).toSet +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala new file mode 100644 index 000000000..b5905a445 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import expressions._ +import rules._ + +object LocalRelation { + def apply(output: Attribute*) = + new LocalRelation(output) +} + +case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) + extends LeafNode with analysis.MultiInstanceRelation { + + // TODO: Validate schema compliance. + def loadData(newData: Seq[Product]) = new LocalRelation(output, data ++ newData) + + /** + * Returns an identical copy of this relation with new exprIds for all attributes. 
Different + * attributes are required when a relation is going to be included multiple times in the same + * query. + */ + override final def newInstance: this.type = { + LocalRelation(output.map(_.newInstance), data).asInstanceOf[this.type] + } + + override protected def stringArgs = Iterator(output) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala new file mode 100644 index 000000000..8e98aab73 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import expressions._ + +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { + def output = projectList.map(_.toAttribute) + def references = projectList.flatMap(_.references).toSet +} + +/** + * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * output of each into a new stream of rows. 
This operation is similar to a `flatMap` in functional + * programming with one important additional feature, which allows the input rows to be joined with + * their output. + * @param join when true, each output row is implicitly joined with the input tuple that produced + * it. + * @param outer when true, each input row will be output at least once, even if the output of the + * given `generator` is empty. `outer` has no effect when `join` is false. + * @param alias when set, this string is applied to the schema of the output of the transformation + * as a qualifier. + */ +case class Generate( + generator: Generator, + join: Boolean, + outer: Boolean, + alias: Option[String], + child: LogicalPlan) + extends UnaryNode { + + protected def generatorOutput = + alias + .map(a => generator.output.map(_.withQualifiers(a :: Nil))) + .getOrElse(generator.output) + + def output = + if (join) child.output ++ generatorOutput else generatorOutput + + def references = + if (join) child.outputSet else generator.references +} + +case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { + def output = child.output + def references = condition.references +} + +case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + // TODO: These aren't really the same attributes as nullability etc might change. 
+ def output = left.output + + override lazy val resolved = + childrenResolved && + !left.output.zip(right.output).exists { case (l,r) => l.dataType != r.dataType } + + def references = Set.empty +} + +case class Join( + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + condition: Option[Expression]) extends BinaryNode { + + def references = condition.map(_.references).getOrElse(Set.empty) + def output = left.output ++ right.output +} + +case class InsertIntoTable( + table: BaseRelation, + partition: Map[String, Option[String]], + child: LogicalPlan, + overwrite: Boolean) + extends LogicalPlan { + // The table being inserted into is a child for the purposes of transformations. + def children = table :: child :: Nil + def references = Set.empty + def output = child.output + + override lazy val resolved = childrenResolved && child.output.zip(table.output).forall { + case (childAttr, tableAttr) => childAttr.dataType == tableAttr.dataType + } +} + +case class InsertIntoCreatedTable( + databaseName: Option[String], + tableName: String, + child: LogicalPlan) extends UnaryNode { + def references = Set.empty + def output = child.output +} + +case class WriteToFile( + path: String, + child: LogicalPlan) extends UnaryNode { + def references = Set.empty + def output = child.output +} + +case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { + def output = child.output + def references = order.flatMap(_.references).toSet +} + +case class Aggregate( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: LogicalPlan) + extends UnaryNode { + + def output = aggregateExpressions.map(_.toAttribute) + def references = child.references +} + +case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode { + def output = child.output + def references = limit.references +} + +case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { + def output = 
child.output.map(_.withQualifiers(alias :: Nil)) + def references = Set.empty +} + +case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan) + extends UnaryNode { + + def output = child.output + def references = Set.empty +} + +case class Distinct(child: LogicalPlan) extends UnaryNode { + def output = child.output + def references = child.outputSet +} + +case object NoRelation extends LeafNode { + def output = Nil +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala new file mode 100644 index 000000000..f7fcdc5fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package logical + +import expressions._ + +/** + * Performs a physical redistribution of the data. Used when the consumer of the query + * result have expectations about the distribution and ordering of partitioned input data. 
+ */ +abstract class RedistributeData extends UnaryNode { + self: Product => + + def output = child.output +} + +case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) + extends RedistributeData { + + def references = sortExpressions.flatMap(_.references).toSet +} + +case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) + extends RedistributeData { + + def references = partitionExpressions.flatMap(_.references).toSet +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala new file mode 100644 index 000000000..a40ab4bbb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A collection of common abstractions for query plans as well as + * a base logical plan representation. 
+ */ +package object plans diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala new file mode 100644 index 000000000..2d8f3ad33 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package plans +package physical + +import expressions._ +import types._ + +/** + * Specifies how tuples that share common expressions will be distributed when a query is executed + * in parallel on many machines. Distribution can be used to refer to two distinct physical + * properties: + * - Inter-node partitioning of data: In this case the distribution describes how tuples are + * partitioned across physical machines in a cluster. Knowing this property allows some + * operators (e.g., Aggregate) to perform partition local operations instead of global ones. + * - Intra-partition ordering of data: In this case the distribution describes guarantees made + * about how tuples are distributed within a single partition. 
+ */ +sealed trait Distribution + +/** + * Represents a distribution where no promises are made about co-location of data. + */ +case object UnspecifiedDistribution extends Distribution + +/** + * Represents a distribution that only has a single partition and all tuples of the dataset + * are co-located. + */ +case object AllTuples extends Distribution + +/** + * Represents data where tuples that share the same values for the `clustering` + * [[catalyst.expressions.Expression Expressions]] will be co-located. Based on the context, this + * can mean such tuples are either co-located in the same partition or they will be contiguous + * within a single partition. + */ +case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { + require( + clustering != Nil, + "The clustering expressions of a ClusteredDistribution should not be Nil. " + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") +} + +/** + * Represents data where tuples have been ordered according to the `ordering` + * [[catalyst.expressions.Expression Expressions]]. This is a strictly stronger guarantee than + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for + * the ordering expressions are contiguous and will never be split across partitions. + */ +case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { + require( + ordering != Nil, + "The ordering expressions of a OrderedDistribution should not be Nil. 
" + + "An AllTuples should be used to represent a distribution that only has " + + "a single partition.") + + def clustering = ordering.map(_.child).toSet +} + +sealed trait Partitioning { + /** Returns the number of partitions that the data is split across */ + val numPartitions: Int + + /** + * Returns true iff the guarantees made by this + * [[catalyst.plans.physical.Partitioning Partitioning]] are sufficient to satisfy + * the partitioning scheme mandated by the `required` + * [[catalyst.plans.physical.Distribution Distribution]], i.e. the current dataset does not + * need to be re-partitioned for the `required` Distribution (it is possible that tuples within + * a partition need to be reorganized). + */ + def satisfies(required: Distribution): Boolean + + /** + * Returns true iff all distribution guarantees made by this partitioning can also be made + * for the `other` specified partitioning. + * For example, two [[catalyst.plans.physical.HashPartitioning HashPartitioning]]s are + * only compatible if the `numPartitions` of them is the same. 
+ */ + def compatibleWith(other: Partitioning): Boolean +} + +case class UnknownPartitioning(numPartitions: Int) extends Partitioning { + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case _ => false + } + + override def compatibleWith(other: Partitioning): Boolean = other match { + case UnknownPartitioning(_) => true + case _ => false + } +} + +case object SinglePartition extends Partitioning { + val numPartitions = 1 + + override def satisfies(required: Distribution): Boolean = true + + override def compatibleWith(other: Partitioning) = other match { + case SinglePartition => true + case _ => false + } +} + +case object BroadcastPartitioning extends Partitioning { + val numPartitions = 1 + + override def satisfies(required: Distribution): Boolean = true + + override def compatibleWith(other: Partitioning) = other match { + case SinglePartition => true + case _ => false + } +} + +/** + * Represents a partitioning where rows are split up across partitions based on the hash + * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be + * in the same partition. 
+ */ +case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) + extends Expression + with Partitioning { + + def children = expressions + def references = expressions.flatMap(_.references).toSet + def nullable = false + def dataType = IntegerType + + lazy val clusteringSet = expressions.toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case ClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning) = other match { + case BroadcastPartitioning => true + case h: HashPartitioning if h == this => true + case _ => false + } +} + +/** + * Represents a partitioning where rows are split across partitions based on some total ordering of + * the expressions specified in `ordering`. When data is partitioned in this manner the following + * two conditions are guaranteed to hold: + * - All rows where the expressions in `ordering` evaluate to the same values will be in the same + * partition. + * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows + * that are in between `min` and `max` in this `ordering` will reside in this partition. 
+ */ +case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) + extends Expression + with Partitioning { + + def children = ordering + def references = ordering.flatMap(_.references).toSet + def nullable = false + def dataType = IntegerType + + lazy val clusteringSet = ordering.map(_.child).toSet + + override def satisfies(required: Distribution): Boolean = required match { + case UnspecifiedDistribution => true + case OrderedDistribution(requiredOrdering) => + val minSize = Seq(requiredOrdering.size, ordering.size).min + requiredOrdering.take(minSize) == ordering.take(minSize) + case ClusteredDistribution(requiredClustering) => + clusteringSet.subsetOf(requiredClustering.toSet) + case _ => false + } + + override def compatibleWith(other: Partitioning) = other match { + case BroadcastPartitioning => true + case r: RangePartitioning if r == this => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala new file mode 100644 index 000000000..6ff4891a3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package rules + +import trees._ + +abstract class Rule[TreeType <: TreeNode[_]] extends Logging { + + /** Name for this rule, automatically inferred based on class name. */ + val ruleName: String = { + val className = getClass.getName + if (className endsWith "$") className.dropRight(1) else className + } + + def apply(plan: TreeType): TreeType +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala new file mode 100644 index 000000000..68ae30cde --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package rules + +import trees._ +import util._ + +abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { + + /** + * An execution strategy for rules that indicates the maximum number of executions. If the + * execution reaches fix point (i.e. 
converge) before maxIterations, it will stop. + */ + abstract class Strategy { def maxIterations: Int } + + /** A strategy that only runs once. */ + case object Once extends Strategy { val maxIterations = 1 } + + /** A strategy that runs until fix point or maxIterations times, whichever comes first. */ + case class FixedPoint(maxIterations: Int) extends Strategy + + /** A batch of rules. */ + protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) + + /** Defines a sequence of rule batches, to be overridden by the implementation. */ + protected val batches: Seq[Batch] + + /** + * Executes the batches of rules defined by the subclass. The batches are executed serially + * using the defined execution strategy. Within each batch, rules are also executed serially. + */ + def apply(plan: TreeType): TreeType = { + var curPlan = plan + + batches.foreach { batch => + var iteration = 1 + var lastPlan = curPlan + curPlan = batch.rules.foldLeft(curPlan) { case (curPlan, rule) => rule(curPlan) } + + // Run until fix point (or the max number of iterations as specified in the strategy. 
+ while (iteration < batch.strategy.maxIterations && !curPlan.fastEquals(lastPlan)) { + lastPlan = curPlan + curPlan = batch.rules.foldLeft(curPlan) { + case (curPlan, rule) => + val result = rule(curPlan) + if (!result.fastEquals(curPlan)) { + logger.debug( + s""" + |=== Applying Rule ${rule.ruleName} === + |${sideBySide(curPlan.treeString, result.treeString).mkString("\n")} + """.stripMargin) + } + + result + } + iteration += 1 + } + } + + curPlan + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala new file mode 100644 index 000000000..26ab54308 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A framework for applying batches rewrite rules to trees, possibly to fixed point. 
+ */ +package object rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala new file mode 100644 index 000000000..76ede87e4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package trees + +import errors._ + +object TreeNode { + private val currentId = new java.util.concurrent.atomic.AtomicLong + protected def nextId() = currentId.getAndIncrement() +} + +/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ +private class MutableInt(var i: Int) + +abstract class TreeNode[BaseType <: TreeNode[BaseType]] { + self: BaseType with Product => + + /** Returns a Seq of the children of this node */ + def children: Seq[BaseType] + + /** + * A globally unique id for this specific instance. Not preserved across copies. + * Unlike `equals`, `id` can be used to differentiate distinct but structurally + * identical branches of a tree. 
+ */ + val id = TreeNode.nextId() + + /** + * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance. Unlike + * `equals` this function will return false for different instances of structurally identical + * trees. + */ + def sameInstance(other: TreeNode[_]): Boolean = { + this.id == other.id + } + + /** + * Faster version of equality which short-circuits when two treeNodes are the same instance. + * We don't just override Object.Equals, as doing so prevents the scala compiler from + * generating case class `equals` methods + */ + def fastEquals(other: TreeNode[_]): Boolean = { + sameInstance(other) || this == other + } + + /** + * Runs the given function on this node and then recursively on [[children]]. + * @param f the function to be applied to each node in the tree. + */ + def foreach(f: BaseType => Unit): Unit = { + f(this) + children.foreach(_.foreach(f)) + } + + /** + * Returns a Seq containing the result of applying the given function to each + * node in this tree in a preorder traversal. + * @param f the function to be applied. + */ + def map[A](f: BaseType => A): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret += f(_)) + ret + } + + /** + * Returns a Seq by applying a function to all nodes in this tree and using the elements of the + * resulting collections. + */ + def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = { + val ret = new collection.mutable.ArrayBuffer[A]() + foreach(ret ++= f(_)) + ret + } + + /** + * Returns a Seq containing the result of applying a partial function to all elements in this + * tree on which the function is defined. + */ + def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = { + val ret = new collection.mutable.ArrayBuffer[B]() + val lifted = pf.lift + foreach(node => lifted(node).foreach(ret.+=)) + ret + } + + /** + * Returns a copy of this node where `f` has been applied to all the nodes children. 
+ */ + def mapChildren(f: BaseType => BaseType): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = f(arg.asInstanceOf[BaseType]) + if (newChild fastEquals arg) { + arg + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node with the children replaced. + * TODO: Validate somewhere (in debug mode?) that children are ordered correctly. + */ + def withNewChildren(newChildren: Seq[BaseType]): this.type = { + assert(newChildren.size == children.size, "Incorrect number of children") + var changed = false + val remainingNewChildren = newChildren.toBuffer + val remainingOldChildren = children.toBuffer + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to the tree. + * When `rule` does not apply to a given node it is left unchanged. + * Users should not expect a specific directionality. If a specific directionality is needed, + * transformDown or transformUp should be used. + * @param rule the function use to transform this nodes children + */ + def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = { + transformDown(rule) + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to it and all of its + * children (pre-order). When `rule` does not apply to a given node it is left unchanged. 
+ * @param rule the function used to transform this nodes children + */ + def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRule = rule.applyOrElse(this, identity[BaseType]) + // Check if unchanged and then possibly return old copy to avoid gc churn. + if (this fastEquals afterRule) { + transformChildrenDown(rule) + } else { + afterRule.transformChildrenDown(rule) + } + } + + /** + * Returns a copy of this node where `rule` has been recursively applied to all the children of + * this node. When `rule` does not apply to a given node it is left unchanged. + * @param rule the function used to transform this nodes children + */ + def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case m: Map[_,_] => m + case args: Traversable[_] => args.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Returns a copy of this node where `rule` has been recursively applied first to all of its + * children and then itself (post-order). When `rule` does not apply to a given node, it is left + * unchanged. 
+ * @param rule the function use to transform this nodes children + */ + def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { + val afterRuleOnChildren = transformChildrenUp(rule); + if (this fastEquals afterRuleOnChildren) { + rule.applyOrElse(this, identity[BaseType]) + } else { + rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) + } + } + + def transformChildrenUp(rule: PartialFunction[BaseType, BaseType]): this.type = { + var changed = false + val newArgs = productIterator.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case m: Map[_,_] => m + case args: Traversable[_] => args.map { + case arg: TreeNode[_] if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + } + case nonChild: AnyRef => nonChild + case null => null + }.toArray + if (changed) makeCopy(newArgs) else this + } + + /** + * Args to the constructor that should be copied, but not transformed. + * These are appended to the transformed args automatically by makeCopy + * @return + */ + protected def otherCopyArgs: Seq[AnyRef] = Nil + + /** + * Creates a copy of this type of tree node after a transformation. + * Must be overridden by child classes that have constructor arguments + * that are not present in the productIterator. + * @param newArgs the new product arguments. 
+ */ + def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + try { + val defaultCtor = getClass.getConstructors.head + if (otherCopyArgs.isEmpty) { + defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] + } else { + defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type] + } + } catch { + case e: java.lang.IllegalArgumentException => + throw new TreeNodeException( + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + } + } + + /** Returns the name of this type of TreeNode. Defaults to the class name. */ + def nodeName = getClass.getSimpleName + + /** + * The arguments that should be included in the arg string. Defaults to the `productIterator`. + */ + protected def stringArgs = productIterator + + /** Returns a string representing the arguments to this node, minus any children */ + def argString: String = productIterator.flatMap { + case tn: TreeNode[_] if children contains tn => Nil + case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil + case set: Set[_] => set.mkString("{", ",", "}") :: Nil + case other => other :: Nil + }.mkString(", ") + + /** String representation of this node without any children */ + def simpleString = s"$nodeName $argString" + + override def toString: String = treeString + + /** Returns a string representation of the nodes in this tree */ + def treeString = generateTreeString(0, new StringBuilder).toString + + /** + * Returns a string representation of the nodes in this tree, where each operator is numbered. + * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. + */ + def numberedTreeString = + treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") + + /** + * Returns the tree node at the specified number. 
+ * Numbers for each node can be found in the [[numberedTreeString]]. + */ + def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number)) + + protected def getNodeNumbered(number: MutableInt): BaseType = { + if (number.i < 0) { + null.asInstanceOf[BaseType] + } else if (number.i == 0) { + this + } else { + number.i -= 1 + children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType]) + } + } + + /** Appends the string represent of this node and its children to the given StringBuilder. */ + protected def generateTreeString(depth: Int, builder: StringBuilder): StringBuilder = { + builder.append(" " * depth) + builder.append(simpleString) + builder.append("\n") + children.foreach(_.generateTreeString(depth + 1, builder)) + builder + } +} + +/** + * A [[TreeNode]] that has two children, [[left]] and [[right]]. + */ +trait BinaryNode[BaseType <: TreeNode[BaseType]] { + def left: BaseType + def right: BaseType + + def children = Seq(left, right) +} + +/** + * A [[TreeNode]] with no children. + */ +trait LeafNode[BaseType <: TreeNode[BaseType]] { + def children = Nil +} + +/** + * A [[TreeNode]] with a single [[child]]. + */ +trait UnaryNode[BaseType <: TreeNode[BaseType]] { + def child: BaseType + def children = child :: Nil +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala new file mode 100644 index 000000000..e2da1d243 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +/** + * A library for easily manipulating trees of operators. Operators that extend TreeNode are + * granted the following interface: + *
+ * <ul>
+ *   <li>Scala collection like methods (foreach, map, flatMap, collect, etc)</li>
+ *   <li>
+ *     transform - accepts a partial function that is used to generate a new tree. When the
+ *     partial function can be applied to a given tree segment, that segment is replaced with the
+ *     result. After attempting to apply the partial function to a given node, the transform
+ *     function recursively attempts to apply the function to that node's children.
+ *   </li>
+ *   <li>debugging support - pretty printing, easy splicing of trees, etc.</li>
+ * </ul>
    + */ +package object trees { + // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. + protected val logger = Logger("catalyst.trees") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala new file mode 100644 index 000000000..6eb2b62ec --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package types + +import expressions.Expression + +abstract class DataType { + /** Matches any expression that evaluates to this DataType */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } +} + +case object NullType extends DataType + +abstract class NativeType extends DataType { + type JvmType + val ordering: Ordering[JvmType] +} + +case object StringType extends NativeType { + type JvmType = String + val ordering = implicitly[Ordering[JvmType]] +} +case object BinaryType extends DataType { + type JvmType = Array[Byte] +} +case object BooleanType extends NativeType { + type JvmType = Boolean + val ordering = implicitly[Ordering[JvmType]] +} + +abstract class NumericType extends NativeType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
+ val numeric: Numeric[JvmType] +} + +/** Matcher for any expressions that evaluate to [[IntegralType]]s */ +object IntegralType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[IntegralType] => true + case _ => false + } +} + +abstract class IntegralType extends NumericType { + val integral: Integral[JvmType] +} + +case object LongType extends IntegralType { + type JvmType = Long + val numeric = implicitly[Numeric[Long]] + val integral = implicitly[Integral[Long]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object IntegerType extends IntegralType { + type JvmType = Int + val numeric = implicitly[Numeric[Int]] + val integral = implicitly[Integral[Int]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object ShortType extends IntegralType { + type JvmType = Short + val numeric = implicitly[Numeric[Short]] + val integral = implicitly[Integral[Short]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object ByteType extends IntegralType { + type JvmType = Byte + val numeric = implicitly[Numeric[Byte]] + val integral = implicitly[Integral[Byte]] + val ordering = implicitly[Ordering[JvmType]] +} + +/** Matcher for any expressions that evaluate to [[FractionalType]]s */ +object FractionalType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[FractionalType] => true + case _ => false + } +} +abstract class FractionalType extends NumericType { + val fractional: Fractional[JvmType] +} + +case object DecimalType extends FractionalType { + type JvmType = BigDecimal + val numeric = implicitly[Numeric[BigDecimal]] + val fractional = implicitly[Fractional[BigDecimal]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object DoubleType extends FractionalType { + type JvmType = Double + val numeric = implicitly[Numeric[Double]] + val fractional = implicitly[Fractional[Double]] + val ordering = implicitly[Ordering[JvmType]] +} + +case object 
FloatType extends FractionalType { + type JvmType = Float + val numeric = implicitly[Numeric[Float]] + val fractional = implicitly[Fractional[Float]] + val ordering = implicitly[Ordering[JvmType]] +} + +case class ArrayType(elementType: DataType) extends DataType + +case class StructField(name: String, dataType: DataType, nullable: Boolean) +case class StructType(fields: Seq[StructField]) extends DataType + +case class MapType(keyType: DataType, valueType: DataType) extends DataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala new file mode 100644 index 000000000..b65a5617d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +/** + * Contains a type system for attributes produced by relations, including complex types like + * structs, arrays and maps. 
+ */ +package object types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala new file mode 100644 index 000000000..52adea266 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst + +import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} + +package object util { + /** + * Returns a path to a temporary file that probably does not exist. + * Note, there is always the race condition that someone created this + * file since the last time we checked. Thus, this shouldn't be used + * for anything security conscious. 
+ */ + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + def fileToString(file: File, encoding: String = "UTF-8") = { + val inStream = new FileInputStream(file) + val outStream = new ByteArrayOutputStream + try { + var reading = true + while ( reading ) { + inStream.read() match { + case -1 => reading = false + case c => outStream.write(c) + } + } + outStream.flush() + } + finally { + inStream.close() + } + new String(outStream.toByteArray, encoding) + } + + def resourceToString( + resource:String, + encoding: String = "UTF-8", + classLoader: ClassLoader = this.getClass.getClassLoader) = { + val inStream = classLoader.getResourceAsStream(resource) + val outStream = new ByteArrayOutputStream + try { + var reading = true + while ( reading ) { + inStream.read() match { + case -1 => reading = false + case c => outStream.write(c) + } + } + outStream.flush() + } + finally { + inStream.close() + } + new String(outStream.toByteArray, encoding) + } + + def stringToFile(file: File, str: String): File = { + val out = new PrintWriter(file) + out.write(str) + out.close() + file + } + + def sideBySide(left: String, right: String): Seq[String] = { + sideBySide(left.split("\n"), right.split("\n")) + } + + def sideBySide(left: Seq[String], right: Seq[String]): Seq[String] = { + val maxLeftSize = left.map(_.size).max + val leftPadded = left ++ Seq.fill(math.max(right.size - left.size, 0))("") + val rightPadded = right ++ Seq.fill(math.max(left.size - right.size, 0))("") + + leftPadded.zip(rightPadded).map { + case (l, r) => (if (l == r) " " else "!") + l + (" " * ((maxLeftSize - l.size) + 3)) + r + } + } + + def stackTraceToString(t: Throwable): String = { + val out = new java.io.ByteArrayOutputStream + val writer = new PrintWriter(out) + t.printStackTrace(writer) + writer.flush() + new String(out.toByteArray) + } + + def stringOrNull(a: AnyRef) = if (a == null) null else 
a.toString + + def benchmark[A](f: => A): A = { + val startTime = System.nanoTime() + val ret = f + val endTime = System.nanoTime() + println(s"${(endTime - startTime).toDouble / 1000000}ms") + ret + } + + /* FIX ME + implicit class debugLogging(a: AnyRef) { + def debugLogging() { + org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG) + } + } */ +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 000000000..9ec31689b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * Note that this package is located in catalyst instead of in core so that all subprojects can + * inherit the settings from this package object. 
+ */ +package object sql { + + protected[sql] def Logger(name: String) = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name)) + + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + + type Row = catalyst.expressions.Row + + object Row { + /** + * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + */ + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala new file mode 100644 index 000000000..1fd0d26b6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/AnalysisSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package analysis + +import org.scalatest.FunSuite + +import analysis._ +import expressions._ +import plans.logical._ +import types._ + +import dsl._ +import dsl.expressions._ + +class AnalysisSuite extends FunSuite { + val analyze = SimpleAnalyzer + + val testRelation = LocalRelation('a.int) + + test("analyze project") { + assert(analyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) + + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala new file mode 100644 index 000000000..fb25e1c24 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.plans.physical._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class DistributionSuite extends FunSuite { + + protected def checkSatisfied( + inputPartitioning: Partitioning, + requiredDistribution: Distribution, + satisfied: Boolean) { + if (inputPartitioning.satisfies(requiredDistribution) != satisfied) + fail( + s""" + |== Input Partitioning == + |$inputPartitioning + |== Required Distribution == + |$requiredDistribution + |== Does input partitioning satisfy required distribution? == + |Expected $satisfied got ${inputPartitioning.satisfies(requiredDistribution)} + """.stripMargin) + } + + test("HashPartitioning is the output partitioning") { + // Cases which do not need an exchange between two data properties. + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + SinglePartition, + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + // Cases which need an exchange between two data properties. 
+ checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('b, 'c)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + ClusteredDistribution(Seq('d, 'e)), + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + AllTuples, + false) + + checkSatisfied( + HashPartitioning(Seq('a, 'b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + checkSatisfied( + HashPartitioning(Seq('b, 'c), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + false) + + // TODO: We should check functional dependencies + /* + checkSatisfied( + ClusteredDistribution(Seq('b)), + ClusteredDistribution(Seq('b + 1)), + true) + */ + } + + test("RangePartitioning is the output partitioning") { + // Cases which do not need an exchange between two data properties. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + UnspecifiedDistribution, + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc, 'd.desc)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b, 'c)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'b, 'a)), + true) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('b, 'c, 'a, 'd)), + true) + + // Cases which need an exchange between two data properties. + // TODO: We can have an optimization to first sort the dataset + // by a.asc and then sort b, and c in a partition. 
This optimization + // should tradeoff the benefit of a less number of Exchange operators + // and the parallelism. + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('a.asc, 'b.desc, 'c.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + OrderedDistribution(Seq('b.asc, 'a.asc)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('a, 'b)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + ClusteredDistribution(Seq('c, 'd)), + false) + + checkSatisfied( + RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10), + AllTuples, + false) + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala new file mode 100644 index 000000000..f06618ad1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ExpressionEvaluationSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package expressions + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types._ + +/* Implict conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class ExpressionEvaluationSuite extends FunSuite { + + test("literals") { + assert((Literal(1) + Literal(1)).apply(null) === 2) + } + + /** + * Checks for three-valued-logic. Based on: + * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 + * + * p q p OR q p AND q p = q + * True True True True True + * True False True False False + * True Unknown True Unknown Unknown + * False True True False False + * False False False False True + * False Unknown Unknown False Unknown + * Unknown True True Unknown Unknown + * Unknown False Unknown False Unknown + * Unknown Unknown Unknown Unknown Unknown + * + * p NOT p + * True False + * False True + * Unknown Unknown + */ + + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil + + test("3VL Not") { + notTrueTable.foreach { + case (v, answer) => + val expr = Not(Literal(v, BooleanType)) + val result = expr.apply(null) + if (result != answer) + fail(s"$expr should not evaluate to $result, expected: $answer") } + } + + booleanLogicTest("AND", _ && _, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: + (false, false, false) :: + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) + + booleanLogicTest("OR", _ || _, + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: + (false, false, false) :: + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) + + booleanLogicTest("=", _ === _, + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, 
false) :: + (false, false, true) :: + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) + + def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) { + test(s"3VL $name") { + truthTable.foreach { + case (l,r,answer) => + val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) + val result = expr.apply(null) + if (result != answer) + fail(s"$expr should not evaluate to $result, expected: $answer") + } + } + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala new file mode 100644 index 000000000..f595bf7e4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/HiveTypeCoercionSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package analysis + +import org.scalatest.FunSuite + +import catalyst.types._ + + +class HiveTypeCoercionSuite extends FunSuite { + + val rules = new HiveTypeCoercion { } + import rules._ + + test("tightest common bound for numeric and boolean types") { + def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { + var found = WidenTypes.findTightestCommonType(t1, t2) + assert(found == tightestCommon, + s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. + found = WidenTypes.findTightestCommonType(t2, t1) + assert(found == tightestCommon, + s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") + } + + // Boolean + widenTest(NullType, BooleanType, Some(BooleanType)) + widenTest(BooleanType, BooleanType, Some(BooleanType)) + widenTest(IntegerType, BooleanType, None) + widenTest(LongType, BooleanType, None) + + // Integral + widenTest(NullType, ByteType, Some(ByteType)) + widenTest(NullType, IntegerType, Some(IntegerType)) + widenTest(NullType, LongType, Some(LongType)) + widenTest(ShortType, IntegerType, Some(IntegerType)) + widenTest(ShortType, LongType, Some(LongType)) + widenTest(IntegerType, LongType, Some(LongType)) + widenTest(LongType, LongType, Some(LongType)) + + // Floating point + widenTest(NullType, FloatType, Some(FloatType)) + widenTest(NullType, DoubleType, Some(DoubleType)) + widenTest(FloatType, DoubleType, Some(DoubleType)) + widenTest(FloatType, FloatType, Some(FloatType)) + widenTest(DoubleType, DoubleType, Some(DoubleType)) + + // Integral mixed with floating point. 
+ widenTest(NullType, FloatType, Some(FloatType)) + widenTest(NullType, DoubleType, Some(DoubleType)) + widenTest(IntegerType, FloatType, Some(FloatType)) + widenTest(IntegerType, DoubleType, Some(DoubleType)) + widenTest(IntegerType, DoubleType, Some(DoubleType)) + widenTest(LongType, FloatType, Some(FloatType)) + widenTest(LongType, DoubleType, Some(DoubleType)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala new file mode 100644 index 000000000..ff7c15b71 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/RuleExecutorSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package trees + +import org.scalatest.FunSuite + +import expressions._ +import rules._ + +class RuleExecutorSuite extends FunSuite { + object DecrementLiterals extends Rule[Expression] { + def apply(e: Expression): Expression = e transform { + case IntegerLiteral(i) if i > 0 => Literal(i - 1) + } + } + + test("only once") { + object ApplyOnce extends RuleExecutor[Expression] { + val batches = Batch("once", Once, DecrementLiterals) :: Nil + } + + assert(ApplyOnce(Literal(10)) === Literal(9)) + } + + test("to fixed point") { + object ToFixedPoint extends RuleExecutor[Expression] { + val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil + } + + assert(ToFixedPoint(Literal(10)) === Literal(0)) + } + + test("to maxIterations") { + object ToFixedPoint extends RuleExecutor[Expression] { + val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil + } + + assert(ToFixedPoint(Literal(100)) === Literal(90)) + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala new file mode 100644 index 000000000..98bb090c2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/TreeNodeSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package catalyst +package trees + +import scala.collection.mutable.ArrayBuffer + +import expressions._ + +import org.scalatest.{FunSuite} + +class TreeNodeSuite extends FunSuite { + + test("top node changed") { + val after = Literal(1) transform { case Literal(1, _) => Literal(2) } + assert(after === Literal(2)) + } + + test("one child changed") { + val before = Add(Literal(1), Literal(2)) + val after = before transform { case Literal(2, _) => Literal(1) } + + assert(after === Add(Literal(1), Literal(1))) + } + + test("no change") { + val before = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) + val after = before transform { case Literal(5, _) => Literal(1)} + + assert(before === after) + assert(before.map(_.id) === after.map(_.id)) + } + + test("collect") { + val tree = Add(Literal(1), Add(Literal(2), Add(Literal(3), Literal(4)))) + val literals = tree collect {case l: Literal => l} + + assert(literals.size === 4) + (1 to 4).foreach(i => assert(literals contains Literal(i))) + } + + test("pre-order transform") { + val actual = new ArrayBuffer[String]() + val expected = Seq("+", "1", "*", "2", "-", "3", "4") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression transformDown { + case b: BinaryExpression => {actual.append(b.symbol); b} + case l: Literal => {actual.append(l.toString); l} + } + + assert(expected === actual) + } + + test("post-order transform") { + val actual = new ArrayBuffer[String]() + val expected = Seq("1", "2", "3", "4", "-", "*", 
"+") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression transformUp { + case b: BinaryExpression => {actual.append(b.symbol); b} + case l: Literal => {actual.append(l.toString); l} + } + + assert(expected === actual) + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala new file mode 100644 index 000000000..7ce42b2b0 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package catalyst +package optimizer + +import types.IntegerType +import util._ +import plans.logical.{LogicalPlan, LocalRelation} +import rules._ +import expressions._ +import dsl.plans._ +import dsl.expressions._ + +class ConstantFoldingSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueries) :: + Batch("ConstantFolding", Once, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("eliminate subqueries") { + val originalQuery = + testRelation + .subquery('y) + .select('a) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a.attr) + .analyze + + comparePlans(optimized, correctAnswer) + } + + /** + * Unit tests for constant folding in expressions. + */ + test("Constant folding test: expressions only have literals") { + val originalQuery = + testRelation + .select( + Literal(2) + Literal(3) + Literal(4) as Symbol("2+3+4"), + Literal(2) * Literal(3) + Literal(4) as Symbol("2*3+4"), + Literal(2) * (Literal(3) + Literal(4)) as Symbol("2*(3+4)")) + .where( + Literal(1) === Literal(1) && + Literal(2) > Literal(3) || + Literal(3) > Literal(2) ) + .groupBy( + Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2)) + )(Literal(9) / Literal(3) as Symbol("9/3")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(9) as Symbol("2+3+4"), + Literal(10) as Symbol("2*3+4"), + Literal(14) as Symbol("2*(3+4)")) + .where(Literal(true)) + .groupBy(Literal(3))(Literal(3) as Symbol("9/3")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have attribute references and literals in " + + "arithmetic operations") { + val originalQuery = + testRelation + .select( + Literal(2) + Literal(3) + 'a as Symbol("c1"), + 'a + 
Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * 'a + Literal(4) as Symbol("c3"), + 'a * (Literal(3) + Literal(4)) as Symbol("c4")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(5) + 'a as Symbol("c1"), + 'a + Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * 'a + Literal(4) as Symbol("c3"), + 'a * (Literal(7)) as Symbol("c4")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have attribute references and literals in " + + "predicates") { + val originalQuery = + testRelation + .where( + (('a > 1 && Literal(1) === Literal(1)) || + ('a < 10 && Literal(1) === Literal(2)) || + (Literal(1) === Literal(1) && 'b > 1) || + (Literal(1) === Literal(2) && 'b < 10)) && + (('a > 1 || Literal(1) === Literal(1)) && + ('a < 10 || Literal(1) === Literal(2)) && + (Literal(1) === Literal(1) || 'b > 1) && + (Literal(1) === Literal(2) || 'b < 10))) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .where(('a > 1 || 'b > 1) && ('a < 10 && 'b < 10)) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have foldable functions") { + val originalQuery = + testRelation + .select( + Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), + Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Literal(5) + 'a as Symbol("c1"), + Literal(3) as Symbol("c2")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("Constant folding test: expressions have nonfoldable functions") { + val originalQuery = + testRelation + .select( + Rand + Literal(1) as Symbol("c1"), + Sum('a) as Symbol("c2")) + + val optimized = Optimize(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + Rand + Literal(1.0) as Symbol("c1"), 
+ Sum('a) as Symbol("c2")) + .analyze + + comparePlans(optimized, correctAnswer) + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala new file mode 100644 index 000000000..cd611b3fb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -0,0 +1,222 @@ +package org.apache.spark.sql +package catalyst +package optimizer + +import expressions._ +import plans.logical._ +import rules._ +import util._ + +import dsl.plans._ +import dsl.expressions._ + +class FilterPushdownSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubqueries) :: + Batch("Filter Pushdown", Once, + EliminateSubqueries, + CombineFilters, + PushPredicateThroughProject, + PushPredicateThroughInnerJoin) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + // This test already passes. + test("eliminate subqueries") { + val originalQuery = + testRelation + .subquery('y) + .select('a) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a.attr) + .analyze + + comparePlans(optimized, correctAnswer) + } + + // After this line is unimplemented. 
+ test("simple push down") { + val originalQuery = + testRelation + .select('a) + .where('a === 1) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where('a === 1) + .select('a) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("can't push without rewrite") { + val originalQuery = + testRelation + .select('a + 'b as 'e) + .where('e === 1) + .analyze + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where('a + 'b === 1) + .select('a + 'b as 'e) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("filters: combines filters") { + val originalQuery = testRelation + .select('a) + .where('a === 1) + .where('a === 2) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .where('a === 1 && 'a === 2) + .select('a).analyze + + + comparePlans(optimized, correctAnswer) + } + + + test("joins: push to either side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where("x.b".attr === 1) + .where("y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 1) + val right = testRelation.where('b === 2) + val correctAnswer = + left.join(right).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push to one side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where("x.b".attr === 1) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 1) + val right = testRelation + val correctAnswer = + left.join(right).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: rewrite filter to push to either side") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where("x.b".attr === 1 && "y.b".attr === 2) + } + + val 
optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 1) + val right = testRelation.where('b === 2) + val correctAnswer = + left.join(right).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: can't push down") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some("x.b".attr === "y.b".attr)) + } + val optimized = Optimize(originalQuery.analyze) + + comparePlans(optimizer.EliminateSubqueries(originalQuery.analyze), optimized) + } + + test("joins: conjunctive predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1)) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.where('a === 1).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } + + test("joins: conjunctive predicates #2") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1)) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } + + test("joins: conjunctive predicates #3") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + val z = testRelation.subquery('z) + + val originalQuery = { + z.join(x.join(y)) + .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) + } + + 
val optimized = Optimize(originalQuery.analyze) + val lleft = testRelation.where('a >= 3).subquery('z) + val left = testRelation.where('a === 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + lleft.join( + left.join(right, condition = Some("x.b".attr === "y.b".attr)), + condition = Some("z.a".attr === "x.b".attr)) + .analyze + + comparePlans(optimized, optimizer.EliminateSubqueries(correctAnswer)) + } +} \ No newline at end of file diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala new file mode 100644 index 000000000..7b3653d0f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerTest.scala @@ -0,0 +1,44 @@ +package org.apache.spark.sql +package catalyst +package optimizer + +import org.scalatest.FunSuite + +import types.IntegerType +import util._ +import plans.logical.{LogicalPlan, LocalRelation} +import expressions._ +import dsl._ + +/* Implicit conversions for creating query plans */ + +/** + * Provides helper methods for comparing plans produced by optimization rules with the expected + * result + */ +class OptimizerTest extends FunSuite { + + /** + * Since attribute references are given globally unique ids during analysis, + * we must normalize them to check if two different queries are identical. 
+ */ + protected def normalizeExprIds(plan: LogicalPlan) = { + val minId = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id)).min + plan transformAllExpressions { + case a: AttributeReference => + AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId)) + } + } + + /** Fails the test if the two plans do not match */ + protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + val normalized1 = normalizeExprIds(plan1) + val normalized2 = normalizeExprIds(plan2) + if (normalized1 != normalized2) + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } +} \ No newline at end of file diff --git a/sql/core/pom.xml b/sql/core/pom.xml new file mode 100644 index 000000000..e367edfb1 --- /dev/null +++ b/sql/core/pom.xml @@ -0,0 +1,76 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sql_2.10 + jar + Spark Project SQL + http://spark.apache.org/ + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + + + com.twitter + parquet-column + ${parquet.version} + + + com.twitter + parquet-hadoop + ${parquet.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala new file mode 100644 index 000000000..b8b9e5839 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.language.implicitConversions + +import scala.reflect._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.Aggregator +import org.apache.spark.SparkContext._ +import org.apache.spark.util.collection.AppendOnlyMap + +/** + * Extra functions on RDDs that perform only local operations. These can be used when data has + * already been partitioned correctly. + */ +private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) + extends Logging + with Serializable { + + /** + * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have + * the same number of partitions. Partitions of these two RDDs are cogrouped + * according to the indexes of partitions. If we have two RDDs and + * each of them has n partitions, we will cogroup the partition i from `this` + * with the partition i from `other`. + * This function will not introduce a shuffling operation. 
+ */ + def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => { + val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] + + val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any]) + } + + val getSeq = (k: K) => { + map.changeValue(k, update) + } + + iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 } + iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 } + + map.iterator + }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])} + + cg + } + + /** + * Group the values for each key within a partition of the RDD into a single sequence. + * This function will not introduce a shuffling operation. + */ + def groupByKeyLocally(): RDD[(K, Seq[V])] = { + def createCombiner(v: V) = ArrayBuffer(v) + def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _) + val bufs = self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) + }, preservesPartitioning = true) + bufs.asInstanceOf[RDD[(K, Seq[V])]] + } + + /** + * Join corresponding partitions of `this` and `other`. + * If we have two RDDs and each of them has n partitions, + * we will join the partition i from `this` with the partition i from `other`. + * This function will not introduce a shuffling operation. 
+ */ + def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { + cogroupLocally(other).flatMapValues { + case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + } + } +} + +private[spark] object PartitionLocalRDDFunctions { + implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = + new PartitionLocalRDDFunctions(rdd) +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala new file mode 100644 index 000000000..587cc7487 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.planning.QueryPlanner +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand, WriteToFile} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.execution._ + +/** + * ALPHA COMPONENT + * + * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] + * objects and the execution of SQL queries. + * + * @groupname userf Spark SQL Functions + * @groupname Ungrouped Support functions for language integrated queries. + */ +class SQLContext(@transient val sparkContext: SparkContext) + extends Logging + with dsl.ExpressionConversions + with Serializable { + + self => + + @transient + protected[sql] lazy val catalog: Catalog = new SimpleCatalog + @transient + protected[sql] lazy val analyzer: Analyzer = + new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + @transient + protected[sql] val optimizer = Optimizer + @transient + protected[sql] val parser = new catalyst.SqlParser + + protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql) + protected[sql] def executeSql(sql: String): this.QueryExecution = executeSql(parseSql(sql)) + protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + /** + * EXPERIMENTAL + * + * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan + * interface is considered internal, and thus not guaranteed to be stable. As a result, using + * them directly is not recommended. 
+ */ + implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = new SchemaRDD(this, plan) + + /** + * Creates a SchemaRDD from an RDD of case classes. + * + * @group userf + */ + implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + + /** + * Loads a parquet file, returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + def parquetFile(path: String): SchemaRDD = + new SchemaRDD(this, parquet.ParquetRelation("ParquetFile", path)) + + + /** + * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only + * during the lifetime of this instance of SQLContext. + * + * @group userf + */ + def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { + catalog.registerTable(None, tableName, rdd.logicalPlan) + } + + /** + * Executes a SQL query using Spark, returning the result as a SchemaRDD. + * + * @group userf + */ + def sql(sqlText: String): SchemaRDD = { + val result = new SchemaRDD(this, parseSql(sqlText)) + // We force query optimization to happen right away instead of letting it happen lazily like + // when using the query DSL. This is so DDL commands behave as expected. This only + // generates the RDD lineage for DML queries, but does not perform any execution. + result.queryExecution.toRdd + result + } + + protected[sql] class SparkPlanner extends SparkStrategies { + val sparkContext = self.sparkContext + + val strategies: Seq[Strategy] = + TopK :: + PartialAggregation :: + SparkEquiInnerJoin :: + BasicOperators :: + CartesianProduct :: + BroadcastNestedLoopJoin :: Nil + } + + @transient + protected[sql] val planner = new SparkPlanner + + /** + * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and + * inserting shuffle operations as needed. 
+ */ + @transient + protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { + val batches = + Batch("Add exchange", Once, AddExchange) :: + Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + } + + /** + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + */ + protected abstract class QueryExecution { + def logical: LogicalPlan + + lazy val analyzed = analyzer(logical) + lazy val optimizedPlan = optimizer(analyzed) + // TODO: Don't just pick the first one... + lazy val sparkPlan = planner(optimizedPlan).next() + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[Row] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + override def toString: String = + s"""== Logical Plan == + |${stringOrError(analyzed)} + |== Optimized Logical Plan + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + + /** + * Runs the query after interposing operators that print the result of each intermediate step. + */ + def debugExec() = DebugQuery(executedPlan).execute().collect() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala new file mode 100644 index 000000000..91c3aaa2b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -0,0 +1,342 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. 
+* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.{OneToOneDependency, Dependency, Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.BooleanType + +/** + * ALPHA COMPONENT + * + * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, + * SchemaRDDs can be used in relational queries, as shown in the examples below. + * + * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD + * whose elements are scala case classes into a SchemaRDD. This conversion can also be done + * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. + * + * A `SchemaRDD` can also be created by loading data in from external sources, for example, + * by using the `parquetFile` method on [[SQLContext]]. + * + * == SQL Queries == + * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once + * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. + * + * {{{ + * // One method for defining the schema of an RDD is to make a case class with the desired column + * // names and types. 
+ * case class Record(key: Int, value: String) + * + * val sc: SparkContext // An existing spark context. + * val sqlContext = new SQLContext(sc) + * + * // Importing the SQL context gives access to all the SQL functions and implicit conversions. + * import sqlContext._ + * + * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i"))) + * // Any RDD containing case classes can be registered as a table. The schema of the table is + * // automatically inferred using scala reflection. + * rdd.registerAsTable("records") + * + * val results: SchemaRDD = sql("SELECT * FROM records") + * }}} + * + * == Language Integrated Queries == + * + * {{{ + * + * case class Record(key: Int, value: String) + * + * val sc: SparkContext // An existing spark context. + * val sqlContext = new SQLContext(sc) + * + * // Importing the SQL context gives access to all the SQL functions and implicit conversions. + * import sqlContext._ + * + * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) + * + * // Example of language integrated queries. + * rdd.where('key === 1).orderBy('value.asc).select('key).collect() + * }}} + * + * @todo There is currently no support for creating SchemaRDDs from either Java or Python RDDs. + * + * @groupname Query Language Integrated Queries + * @groupdesc Query Functions that create new queries from SchemaRDDs. The + * result of all query functions is also a SchemaRDD, allowing multiple operations to be + * chained using a builder pattern. + * @groupprio Query -2 + * @groupname schema SchemaRDD Functions + * @groupprio schema -1 + * @groupname Ungrouped Base RDD Functions + */ +class SchemaRDD( + @transient val sqlContext: SQLContext, + @transient val logicalPlan: LogicalPlan) + extends RDD[Row](sqlContext.sparkContext, Nil) { + + /** + * A lazily computed query execution workflow. All other RDD operations are passed + * through to the RDD that is produced by this workflow. 
+ *
+ * We want this to be lazy because invoking the whole query optimization pipeline can be
+ * expensive.
+ */
+ @transient
+ protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan)
+
+ override def toString =
+ s"""${super.toString}
+ |== Query Plan ==
+ |${queryExecution.executedPlan}""".stripMargin.trim
+
+ // =========================================================================================
+ // RDD functions: Copy the internal row representation so we present immutable data to users.
+ // =========================================================================================
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Row] =
+ firstParent[Row].compute(split, context).map(_.copy())
+
+ override def getPartitions: Array[Partition] = firstParent[Row].partitions
+
+ override protected def getDependencies: Seq[Dependency[_]] =
+ List(new OneToOneDependency(queryExecution.toRdd))
+
+
+ // =======================================================================
+ // Query DSL
+ // =======================================================================
+
+ /**
+ * Changes the output of this relation to the given expressions, similar to the `SELECT` clause
+ * in SQL.
+ *
+ * {{{
+ * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName)
+ * }}}
+ *
+ * @param exprs a set of logical expressions that will be evaluated for each input row.
+ *
+ * @group Query
+ */
+ def select(exprs: NamedExpression*): SchemaRDD =
+ new SchemaRDD(sqlContext, Project(exprs, logicalPlan))
+
+ /**
+ * Filters the output, only returning those rows where `condition` evaluates to true.
+ *
+ * {{{
+ * schemaRDD.where('a === 'b)
+ * schemaRDD.where('a === 1)
+ * schemaRDD.where('a + 'b > 10)
+ * }}}
+ *
+ * @group Query
+ */
+ def where(condition: Expression): SchemaRDD =
+ new SchemaRDD(sqlContext, Filter(condition, logicalPlan))
+
+ /**
+ * Performs a relational join on two SchemaRDDs
+ *
+ * @param otherPlan the [[SchemaRDD]] that should be joined with this one.
+ * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.`
+ * @param condition An optional condition for the join operation. This is equivalent to the `ON`
+ * clause in standard SQL. In the case of `Inner` joins, specifying a
+ * `condition` is equivalent to adding `where` clauses after the `join`.
+ *
+ * @group Query
+ */
+ def join(
+ otherPlan: SchemaRDD,
+ joinType: JoinType = Inner,
+ condition: Option[Expression] = None): SchemaRDD =
+ new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, condition))
+
+ /**
+ * Sorts the results by the given expressions.
+ * {{{
+ * schemaRDD.orderBy('a)
+ * schemaRDD.orderBy('a, 'b)
+ * schemaRDD.orderBy('a.asc, 'b.desc)
+ * }}}
+ *
+ * @group Query
+ */
+ def orderBy(sortExprs: SortOrder*): SchemaRDD =
+ new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+
+ /**
+ * Performs a grouping followed by an aggregation.
+ *
+ * {{{
+ * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales)
+ * }}}
+ *
+ * @group Query
+ */
+ def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = {
+ val aliasedExprs = aggregateExprs.map {
+ case ne: NamedExpression => ne
+ case e => Alias(e, e.toString)()
+ }
+ new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
+ }
+
+ /**
+ * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
+ * with the same name, for example, when performing self-joins.
+ * + * {{{ + * val x = schemaRDD.where('a === 1).subquery('x) + * val y = schemaRDD.where('a === 2).subquery('y) + * x.join(y).where("x.a".attr === "y.a".attr), + * }}} + * + * @group Query + */ + def subquery(alias: Symbol) = + new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) + + /** + * Combines the tuples of two RDDs with the same schema, keeping duplicates. + * + * @group Query + */ + def unionAll(otherPlan: SchemaRDD) = + new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) + + /** + * Filters tuples using a function over the value of the specified column. + * + * {{{ + * schemaRDD.sfilter('a)((a: Int) => ...) + * }}} + * + * @group Query + */ + def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = + new SchemaRDD( + sqlContext, + Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) + + /** + * EXPERIMENTAL + * + * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use + * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of + * the column is not known at compile time, all attributes are converted to strings before + * being passed to the function. + * + * {{{ + * schemaRDD.where(r => r.firstName == "Bob" && r.lastName == "Smith") + * }}} + * + * @group Query + */ + def where(dynamicUdf: (DynamicRow) => Boolean) = + new SchemaRDD( + sqlContext, + Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) + + /** + * EXPERIMENTAL + * + * Returns a sampled version of the underlying dataset. + * + * @group Query + */ + def sample( + fraction: Double, + withReplacement: Boolean = true, + seed: Int = (math.random * 1000).toInt) = + new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) + + /** + * EXPERIMENTAL + * + * Applies the given Generator, or table generating function, to this relation. + * + * @param generator A table generating function. 
The API for such functions is likely to change
+ * in future releases
+ * @param join when set to true, each output row of the generator is joined with the input row
+ * that produced it.
+ * @param outer when set to true, at least one row will be produced for each input row, similar to
+ * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a
+ * given row, a single row will be output, with `NULL` values for each of the
+ * generated columns.
+ * @param alias an optional alias that can be used as a qualifier for the attributes that are produced
+ * by this generate operation.
+ *
+ * @group Query
+ */
+ def generate(
+ generator: Generator,
+ join: Boolean = false,
+ outer: Boolean = false,
+ alias: Option[String] = None) =
+ new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan))
+
+ /**
+ * EXPERIMENTAL
+ *
+ * Adds the rows from this RDD to the specified table. Note in a standard [[SQLContext]] there is
+ * no notion of persistent tables, and thus queries that contain this operator will fail to
+ * optimize. When working with an extension of a SQLContext that has a persistent catalog, such
+ * as a `HiveContext`, this operation will result in insertions to the table specified.
+ *
+ * @group schema
+ */
+ def insertInto(tableName: String, overwrite: Boolean = false) =
+ new SchemaRDD(
+ sqlContext,
+ InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite))
+
+ /**
+ * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
+ * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` function.
+ *
+ * @group schema
+ */
+ def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ /**
+ * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
+ * + * @group schema + */ + def registerAsTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + /** + * Returns this RDD as a SchemaRDD. + * @group schema + */ + def toSchemaRDD = this + + def analyze = sqlContext.analyzer(logicalPlan) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala new file mode 100644 index 000000000..72dc5ec6a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, RangePartitioner, HashPartitioner} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.MutablePair + +import catalyst.rules.Rule +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical._ + +private class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { + override def newKryo(): Kryo = { + val kryo = new Kryo + kryo.setRegistrationRequired(true) + kryo.register(classOf[MutablePair[_,_]]) + kryo.register(classOf[Array[Any]]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) + kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]]) + kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) + kryo.setReferences(false) + kryo.setClassLoader(this.getClass.getClassLoader) + kryo + } +} + +private class BigDecimalSerializer extends Serializer[BigDecimal] { + def write(kryo: Kryo, output: Output, bd: math.BigDecimal) { + // TODO: There are probably more efficient representations than strings... + output.writeString(bd.toString) + } + + def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = { + BigDecimal(input.readString()) + } +} + +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { + + override def outputPartitioning = newPartitioning + + def output = child.output + + def execute() = attachTree(this , "execute") { + newPartitioning match { + case HashPartitioning(expressions, numPartitions) => { + // TODO: Eliminate redundant expressions in grouping key and value. 
+ val rdd = child.execute().mapPartitions { iter => + val hashExpressions = new MutableProjection(expressions) + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair.update(hashExpressions(r), r)) + } + val part = new HashPartitioner(numPartitions) + val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.map(_._2) + } + case RangePartitioning(sortingExpressions, numPartitions) => { + // TODO: RangePartitioner should take an Ordering. + implicit val ordering = new RowOrdering(sortingExpressions) + + val rdd = child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Row, Null](null, null) + iter.map(row => mutablePair.update(row, null)) + } + val part = new RangePartitioner(numPartitions, rdd, ascending = true) + val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + + shuffled.map(_._1) + } + case SinglePartition => + child.execute().coalesce(1, true) + + case _ => sys.error(s"Exchange not implemented for $newPartitioning") + // TODO: Handle BroadcastPartitioning. + } + } +} + +/** + * Ensures that the [[catalyst.plans.physical.Partitioning Partitioning]] of input data meets the + * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting + * [[Exchange]] Operators where required. + */ +object AddExchange extends Rule[SparkPlan] { + // TODO: Determine the number of partitions. + val numPartitions = 8 + + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator: SparkPlan => + // Check if every child's outputPartitioning satisfies the corresponding + // required data distribution. 
+ def meetsRequirements = + !operator.requiredChildDistribution.zip(operator.children).map { + case (required, child) => + val valid = child.outputPartitioning.satisfies(required) + logger.debug( + s"${if (valid) "Valid" else "Invalid"} distribution," + + s"required: $required current: ${child.outputPartitioning}") + valid + }.exists(!_) + + // Check if outputPartitionings of children are compatible with each other. + // It is possible that every child satisfies its required data distribution + // but two children have incompatible outputPartitionings. For example, + // A dataset is range partitioned by "a.asc" (RangePartitioning) and another + // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two + // datasets are both clustered by "a", but these two outputPartitionings are not + // compatible. + // TODO: ASSUMES TRANSITIVITY? + def compatible = + !operator.children + .map(_.outputPartitioning) + .sliding(2) + .map { + case Seq(a) => true + case Seq(a,b) => a compatibleWith b + }.exists(!_) + + // Check if the partitioning we want to ensure is the same as the child's output + // partitioning. If so, we do not need to add the Exchange operator. + def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan) = + if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + + if (meetsRequirements && compatible) { + operator + } else { + // At least one child does not satisfies its required data distribution or + // at least one child's outputPartitioning is not compatible with another child's + // outputPartitioning. In this case, we need to add Exchange operators. 
+ val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { + case (AllTuples, child) => + addExchangeIfNecessary(SinglePartition, child) + case (ClusteredDistribution(clustering), child) => + addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) + case (OrderedDistribution(ordering), child) => + addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) + case (UnspecifiedDistribution, child) => child + case (dist, _) => sys.error(s"Don't know how to ensure $dist") + } + operator.withNewChildren(repartitionedChildren) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala new file mode 100644 index 000000000..c1da3653c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import catalyst.expressions._ +import catalyst.types._ + +/** + * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the + * output of each into a new stream of rows. 
This operation is similar to a `flatMap` in functional + * programming with one important additional feature, which allows the input rows to be joined with + * their output. + * @param join when true, each output row is implicitly joined with the input tuple that produced + * it. + * @param outer when true, each input row will be output at least once, even if the output of the + * given `generator` is empty. `outer` has no effect when `join` is false. + */ +case class Generate( + generator: Generator, + join: Boolean, + outer: Boolean, + child: SparkPlan) + extends UnaryNode { + + def output = + if (join) child.output ++ generator.output else generator.output + + def execute() = { + if (join) { + child.execute().mapPartitions { iter => + val nullValues = Seq.fill(generator.output.size)(Literal(null)) + // Used to produce rows with no matches when outer = true. + val outerProjection = + new Projection(child.output ++ nullValues, child.output) + + val joinProjection = + new Projection(child.output ++ generator.output, child.output ++ generator.output) + val joinedRow = new JoinedRow + + iter.flatMap {row => + val outputRows = generator(row) + if (outer && outputRows.isEmpty) { + outerProjection(row) :: Nil + } else { + outputRows.map(or => joinProjection(joinedRow(row, or))) + } + } + } + } else { + child.execute().mapPartitions(iter => iter.flatMap(generator)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala new file mode 100644 index 000000000..7ce8608d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecutionException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +class QueryExecutionException(message: String) extends Exception(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala new file mode 100644 index 000000000..5626181d1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import org.apache.spark.rdd.RDD + +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees + +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { + self: Product => + + // TODO: Move to `DistributedPlan` + /** Specifies how data is partitioned across different nodes in the cluster. */ + def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH! + /** Specifies any partition requirements on the input data for this operator. */ + def requiredChildDistribution: Seq[Distribution] = + Seq.fill(children.size)(UnspecifiedDistribution) + + /** + * Runs this query returning the result as an RDD. + */ + def execute(): RDD[Row] + + /** + * Runs this query returning the result as an array. + */ + def executeCollect(): Array[Row] = execute().collect() + + protected def buildRow(values: Seq[Any]): Row = + new catalyst.expressions.GenericRow(values.toArray) +} + +/** + * Allows already planned SparkQueries to be linked into logical query plans. + * + * Note that in general it is not valid to use this class to link multiple copies of the same + * physical operator into the same query plan as this violates the uniqueness of expression ids. + * Special handling exists for ExistingRdd as these are already leaf operators and thus we can just + * replace the output attributes with new copies of themselves without breaking any attribute + * linking. 
+ */ +case class SparkLogicalPlan(alreadyPlanned: SparkPlan) + extends logical.LogicalPlan with MultiInstanceRelation { + + def output = alreadyPlanned.output + def references = Set.empty + def children = Nil + + override final def newInstance: this.type = { + SparkLogicalPlan( + alreadyPlanned match { + case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case _ => sys.error("Multiple instance of the same relation detected.") + }).asInstanceOf[this.type] + } +} + +trait LeafNode extends SparkPlan with trees.LeafNode[SparkPlan] { + self: Product => +} + +trait UnaryNode extends SparkPlan with trees.UnaryNode[SparkPlan] { + self: Product => + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +trait BinaryNode extends SparkPlan with trees.BinaryNode[SparkPlan] { + self: Product => +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala new file mode 100644 index 000000000..85035b811 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+ package org.apache.spark.sql
+ package execution
+
+ import org.apache.spark.SparkContext
+
+ import catalyst.expressions._
+ import catalyst.planning._
+ import catalyst.plans._
+ import catalyst.plans.logical.LogicalPlan
+ import catalyst.plans.physical._
+ import parquet.ParquetRelation
+ import parquet.InsertIntoParquetTable
+
+ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
+
+ val sparkContext: SparkContext
+
+ object SparkEquiInnerJoin extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
+ logger.debug(s"Considering join: ${predicates ++ condition}")
+ // Find equi-join predicates that can be evaluated before the join, and thus can be used
+ // as join keys. Note we can only mix in the conditions with other predicates because the
+ // match above ensures that this is an Inner join.
+ val (joinPredicates, otherPredicates) = (predicates ++ condition).partition {
+ case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
+ (canEvaluate(l, right) && canEvaluate(r, left)) => true
+ case _ => false
+ }
+
+ val joinKeys = joinPredicates.map {
+ case Equals(l,r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
+ case Equals(l,r) if canEvaluate(l, right) && canEvaluate(r, left) => (r, l)
+ }
+
+ // Do not consider this strategy if there are no join keys.
+ if (joinKeys.nonEmpty) {
+ val leftKeys = joinKeys.map(_._1)
+ val rightKeys = joinKeys.map(_._2)
+
+ val joinOp = execution.SparkEquiInnerJoin(
+ leftKeys, rightKeys, planLater(left), planLater(right))
+
+ // Make sure other conditions are met if present.
+ if (otherPredicates.nonEmpty) { + execution.Filter(combineConjunctivePredicates(otherPredicates), joinOp) :: Nil + } else { + joinOp :: Nil + } + } else { + logger.debug(s"Avoiding spark join with no join keys.") + Nil + } + case _ => Nil + } + + private def combineConjunctivePredicates(predicates: Seq[Expression]) = + predicates.reduceLeft(And) + + /** Returns true if `expr` can be evaluated using only the output of `plan`. */ + protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = + expr.references subsetOf plan.outputSet + } + + object PartialAggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => + // Collect all aggregate expressions. + val allAggregates = + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + // Collect all aggregate expressions that can be computed partially. + val partialAggregates = + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + + // Only do partial aggregation if supported by all aggregate expressions. + if (allAggregates.size == partialAggregates.size) { + // Create a map of expressions to their partial evaluations for all aggregate expressions. + val partialEvaluations: Map[Long, SplitEvaluation] = + partialAggregates.map(a => (a.id, a.asPartial)).toMap + + // We need to pass all grouping expressions though so the grouping can happen a second + // time. However some of them might be unnamed so we alias them allowing them to be + // referenced in the second aggregation. + val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap + + // Replace aggregations with a new expression that computes the result from the already + // computed partial evaluations and grouping values. 
+ val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + case e: Expression if partialEvaluations.contains(e.id) => + partialEvaluations(e.id).finalEvaluation + case e: Expression if namedGroupingExpressions.contains(e) => + namedGroupingExpressions(e).toAttribute + }).asInstanceOf[Seq[NamedExpression]] + + val partialComputation = + (namedGroupingExpressions.values ++ + partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq + + // Construct two phased aggregation. + execution.Aggregate( + partial = false, + namedGroupingExpressions.values.map(_.toAttribute).toSeq, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + groupingExpressions, + partialComputation, + planLater(child))(sparkContext))(sparkContext) :: Nil + } else { + Nil + } + case _ => Nil + } + } + + object BroadcastNestedLoopJoin extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, joinType, condition) => + execution.BroadcastNestedLoopJoin( + planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil + case _ => Nil + } + } + + object CartesianProduct extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.Join(left, right, _, None) => + execution.CartesianProduct(planLater(left), planLater(right)) :: Nil + case logical.Join(left, right, Inner, Some(condition)) => + execution.Filter(condition, + execution.CartesianProduct(planLater(left), planLater(right))) :: Nil + case _ => Nil + } + } + + protected lazy val singleRowRdd = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + + def convertToCatalyst(a: Any): Any = a match { + case s: Seq[Any] => s.map(convertToCatalyst) + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + + object TopK extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case 
logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
+ execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
+ // Can we automate these 'pass through' operations?
+ object BasicOperators extends Strategy {
+ // TODO: Set
+ val numPartitions = 200
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Distinct(child) =>
+ execution.Aggregate(
+ partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+ case logical.Sort(sortExprs, child) =>
+ // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
+ execution.Sort(sortExprs, global = true, planLater(child)):: Nil
+ case logical.SortPartitions(sortExprs, child) =>
+ // This sort only sorts tuples within a partition. Its requiredDistribution will be
+ // an UnspecifiedDistribution.
+ execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
+ case logical.Project(projectList, r: ParquetRelation)
+ if projectList.forall(_.isInstanceOf[Attribute]) =>
+
+ // simple projection of data loaded from Parquet file
+ parquet.ParquetTableScan(
+ projectList.asInstanceOf[Seq[Attribute]],
+ r,
+ None)(sparkContext) :: Nil
+ case logical.Project(projectList, child) =>
+ execution.Project(projectList, planLater(child)) :: Nil
+ case logical.Filter(condition, child) =>
+ execution.Filter(condition, planLater(child)) :: Nil
+ case logical.Aggregate(group, agg, child) =>
+ execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+ case logical.Sample(fraction, withReplacement, seed, child) =>
+ execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
+ case logical.LocalRelation(output, data) =>
+ val dataAsRdd =
+ sparkContext.parallelize(data.map(r =>
+ new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
+ execution.ExistingRdd(output, dataAsRdd) :: Nil
+ case logical.StopAfter(IntegerLiteral(limit), child) =>
execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil + case Unions(unionChildren) => + execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil + case logical.Generate(generator, join, outer, _, child) => + execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case logical.NoRelation => + execution.ExistingRdd(Nil, singleRowRdd) :: Nil + case logical.Repartition(expressions, child) => + execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case logical.WriteToFile(path, child) => + val relation = + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None) + InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil + case p: parquet.ParquetRelation => + parquet.ParquetTableScan(p.output, p, None)(sparkContext) :: Nil + case SparkLogicalPlan(existingPlan) => existingPlan :: Nil + case _ => Nil + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala new file mode 100644 index 000000000..51889c198 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples} +import catalyst.types._ + +import org.apache.spark.rdd.PartitionLocalRDDFunctions._ + +/** + * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each + * group. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +case class Aggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sc: SparkContext) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def otherCopyArgs = sc :: Nil + + def output = aggregateExpressions.map(_.toAttribute) + + /* Replace all aggregate expressions with spark functions that will compute the result. */ + def createAggregateImplementations() = aggregateExpressions.map { agg => + val impl = agg transform { + case a: AggregateExpression => a.newInstance + } + + val remainingAttributes = impl.collect { case a: Attribute => a } + // If any references exist that are not inside agg functions then they must be grouping exprs + // in this case we must rebind them to the grouping tuple. 
+ if (remainingAttributes.nonEmpty) { + val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c } + + // An exact match with a grouping expression + val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match { + case -1 => None + case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute)) + } + + exactGroupingExpr.getOrElse( + sys.error(s"$agg is not in grouping expressions: $groupingExpressions")) + } else { + impl + } + } + + def execute() = attachTree(this, "execute") { + // TODO: If the child of it is an [[catalyst.execution.Exchange]], + // do not evaluate the groupingExpressions again since we have evaluated it + // in the [[catalyst.execution.Exchange]]. + val grouped = child.execute().mapPartitions { iter => + val buildGrouping = new Projection(groupingExpressions) + iter.map(row => (buildGrouping(row), row.copy())) + }.groupByKeyLocally() + + val result = grouped.map { case (group, rows) => + val aggImplementations = createAggregateImplementations() + + // Pull out all the functions so we can feed each row into them. + val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f }) + + rows.foreach { row => + aggFunctions.foreach(_.update(row)) + } + buildRow(aggImplementations.map(_.apply(group))) + } + + // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY... + if (groupingExpressions.isEmpty && result.count == 0) { + // When there is no output to the Aggregate operator, we still output an empty row. 
+ val aggImplementations = createAggregateImplementations() + sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil) + } else { + result + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala new file mode 100644 index 000000000..c6d31d9ab --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution} +import catalyst.plans.logical.LogicalPlan +import catalyst.ScalaReflection + +case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + def output = projectList.map(_.toAttribute) + + def execute() = child.execute().mapPartitions { iter => + @transient val resuableProjection = new MutableProjection(projectList) + iter.map(resuableProjection) + } +} + +case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { + def output = child.output + + def execute() = child.execute().mapPartitions { iter => + iter.filter(condition.apply(_).asInstanceOf[Boolean]) + } +} + +case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan) + extends UnaryNode { + + def output = child.output + + // TODO: How to pick seed? + def execute() = child.execute().sample(withReplacement, fraction, seed) +} + +case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan { + // TODO: attributes output by union should be distinct for nullability purposes + def output = children.head.output + def execute() = sc.union(children.map(_.execute())) + + override def otherCopyArgs = sc :: Nil +} + +case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { + override def otherCopyArgs = sc :: Nil + + def output = child.output + + override def executeCollect() = child.execute().map(_.copy()).take(limit) + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. 
+ def execute() = sc.makeRDD(executeCollect(), 1) +} + +case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) + (@transient sc: SparkContext) extends UnaryNode { + override def otherCopyArgs = sc :: Nil + + def output = child.output + + @transient + lazy val ordering = new RowOrdering(sortOrder) + + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + + // TODO: Terminal split should be implemented differently from non-terminal split. + // TODO: Pick num splits based on |limit|. + def execute() = sc.makeRDD(executeCollect(), 1) +} + + +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + override def requiredChildDistribution = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + @transient + lazy val ordering = new RowOrdering(sortOrder) + + def execute() = attachTree(this, "sort") { + // TODO: Optimize sorting operation? + child.execute() + .mapPartitions( + iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, + preservesPartitioning = true) + } + + def output = child.output +} + +object ExistingRdd { + def convertToCatalyst(a: Any): Any = a match { + case s: Seq[Any] => s.map(convertToCatalyst) + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + + def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + // TODO: Reuse the row, don't use map on the product iterator. Maybe code gen? 
+ data.map(r => new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row) + } + + def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = { + ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd)) + } +} + +case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { + def execute() = rdd +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala new file mode 100644 index 000000000..db259b4c4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +object DebugQuery { + def apply(plan: SparkPlan): SparkPlan = { + val visited = new collection.mutable.HashSet[Long]() + plan transform { + case s: SparkPlan if !visited.contains(s.id) => + visited += s.id + DebugNode(s) + } + } +} + +case class DebugNode(child: SparkPlan) extends UnaryNode { + def references = Set.empty + def output = child.output + def execute() = { + val childRdd = child.execute() + println( + s""" + |========================= + |${child.simpleString} + |========================= + """.stripMargin) + childRdd.foreach(println(_)) + childRdd + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala new file mode 100644 index 000000000..5934fd1b0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import scala.collection.mutable + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext + +import catalyst.errors._ +import catalyst.expressions._ +import catalyst.plans._ +import catalyst.plans.physical.{ClusteredDistribution, Partitioning} + +import org.apache.spark.rdd.PartitionLocalRDDFunctions._ + +case class SparkEquiInnerJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + def output = left.output ++ right.output + + def execute() = attachTree(this, "execute") { + val leftWithKeys = left.execute().mapPartitions { iter => + val generateLeftKeys = new Projection(leftKeys, left.output) + iter.map(row => (generateLeftKeys(row), row.copy())) + } + + val rightWithKeys = right.execute().mapPartitions { iter => + val generateRightKeys = new Projection(rightKeys, right.output) + iter.map(row => (generateRightKeys(row), row.copy())) + } + + // Do the join. + val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) + // Drop join keys and merge input tuples. + joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } + } + + /** + * Filters any rows where the any of the join keys is null, ensuring three-valued + * logic for the equi-join conditions. 
+ */ + protected def filterNulls(rdd: RDD[(Row, Row)]) = + rdd.filter { + case (key: Seq[_], _) => !key.exists(_ == null) + } +} + +case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { + def output = left.output ++ right.output + + def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { + case (l: Row, r: Row) => buildRow(l ++ r) + } +} + +case class BroadcastNestedLoopJoin( + streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) + (@transient sc: SparkContext) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def otherCopyArgs = sc :: Nil + + def output = left.output ++ right.output + + /** The Streamed Relation */ + def left = streamed + /** The Broadcast relation */ + def right = broadcast + + @transient lazy val boundCondition = + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true)) + + + def execute() = { + val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => + val matchedRows = new mutable.ArrayBuffer[Row] + val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val joinedRow = new JoinedRow + + streamedIter.foreach { streamedRow => + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size) { + // TODO: One bitset per partition instead of per row. 
+ val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + matchedRows += buildRow(streamedRow ++ broadcastedRow) + matched = true + includedBroadcastTuples += i + } + i += 1 + } + + if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { + matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + } + } + Iterator((matchedRows, includedBroadcastTuples)) + } + + val includedBroadcastTuples = streamedPlusMatches.map(_._2) + val allIncludedBroadcastTuples = + if (includedBroadcastTuples.count == 0) { + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + } else { + streamedPlusMatches.map(_._2).reduce(_ ++ _) + } + + val rightOuterMatches: Seq[Row] = + if (joinType == RightOuter || joinType == FullOuter) { + broadcastedRelation.value.zipWithIndex.filter { + case (row, i) => !allIncludedBroadcastTuples.contains(i) + }.map { + // TODO: Use projection. + case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + } + } else { + Vector() + } + + // TODO: Breaks lineage. + sc.union( + streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala new file mode 100644 index 000000000..67f6f43f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * An execution engine for relational query plans that runs on top Spark and returns RDDs. + * + * Note that the operators in this package are created automatically by a query planner using a + * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are + * documented here in order to make it easier for others to understand the performance + * characteristics of query plans that are generated by Spark SQL. + */ +package object execution { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala new file mode 100644 index 000000000..e87561fe1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.{IOException, FileNotFoundException} + +import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.fs.permission.FsAction + +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, BaseRelation} +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.ArrayType +import org.apache.spark.sql.catalyst.expressions.{Row, AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.analysis.UnresolvedException + +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} +import parquet.schema.{PrimitiveType => ParquetPrimitiveType} +import parquet.schema.{Type => ParquetType} +import parquet.schema.Type.Repetition +import parquet.io.api.{Binary, RecordConsumer} +import parquet.hadoop.{Footer, ParquetFileWriter, ParquetFileReader} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import parquet.hadoop.util.ContextUtil + +import scala.collection.JavaConversions._ + +/** + * Relation that consists of data stored in a Parquet columnar format. + * + * Users should interact with parquet files though a SchemaRDD, created by a [[SQLContext]] instead + * of using this class directly. + * + * {{{ + * val parquetRDD = sqlContext.parquetFile("path/to/parequet.file") + * }}} + * + * @param tableName The name of the relation that can be used in queries. + * @param path The path to the Parquet file. 
+ */ +case class ParquetRelation(val tableName: String, val path: String) extends BaseRelation { + + /** Schema derived from ParquetFile **/ + def parquetSchema: MessageType = + ParquetTypesConverter + .readMetaData(new Path(path)) + .getFileMetaData + .getSchema + + /** Attributes **/ + val attributes = + ParquetTypesConverter + .convertToAttributes(parquetSchema) + + /** Output **/ + override val output = attributes + + // Parquet files have no concepts of keys, therefore no Partitioner + // Note: we could allow Block level access; needs to be thought through + override def isPartitioned = false +} + +object ParquetRelation { + + // The element type for the RDDs that this relation maps to. + type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow + + /** + * Creates a new ParquetRelation and underlying Parquetfile for the given + * LogicalPlan. Note that this is used inside [[SparkStrategies]] to + * create a resolved relation as a data sink for writing to a Parquetfile. + * The relation is empty but is initialized with ParquetMetadata and + * can be inserted into. + * + * @param pathString The directory the Parquetfile will be stored in. + * @param child The child node that will be used for extracting the schema. + * @param conf A configuration to be used. + * @param tableName The name of the resulting relation. + * @return An empty ParquetRelation with inferred metadata. 
+ */ + def create(pathString: String, + child: LogicalPlan, + conf: Configuration, + tableName: Option[String]): ParquetRelation = { + if (!child.resolved) { + throw new UnresolvedException[LogicalPlan]( + child, + "Attempt to create Parquet table from unresolved child (when schema is not available)") + } + + val name = s"${tableName.getOrElse(child.nodeName)}_parquet" + val path = checkPath(pathString, conf) + ParquetTypesConverter.writeMetaData(child.output, path, conf) + new ParquetRelation(name, path.toString) + } + + private def checkPath(pathStr: String, conf: Configuration): Path = { + if (pathStr == null) { + throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") + } + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && + !fs.getFileStatus(path) + .getPermission + .getUserAction + .implies(FsAction.READ_WRITE)) { + throw new IOException( + s"Unable to create ParquetRelation: path $path not read-writable") + } + path + } +} + +object ParquetTypesConverter { + def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { + // for now map binary to string type + // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema + case ParquetPrimitiveTypeName.BINARY => StringType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => { + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
+ sys.error("Warning: potential loss of precision: converting INT96 to long") + LongType + } + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } + + def fromDataType(ctype: DataType): ParquetPrimitiveTypeName = ctype match { + case StringType => ParquetPrimitiveTypeName.BINARY + case BooleanType => ParquetPrimitiveTypeName.BOOLEAN + case DoubleType => ParquetPrimitiveTypeName.DOUBLE + case ArrayType(ByteType) => ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + case FloatType => ParquetPrimitiveTypeName.FLOAT + case IntegerType => ParquetPrimitiveTypeName.INT32 + case LongType => ParquetPrimitiveTypeName.INT64 + case _ => sys.error(s"Unsupported datatype $ctype") + } + + def consumeType(consumer: RecordConsumer, ctype: DataType, record: Row, index: Int): Unit = { + ctype match { + case StringType => consumer.addBinary( + Binary.fromByteArray( + record(index).asInstanceOf[String].getBytes("utf-8") + ) + ) + case IntegerType => consumer.addInteger(record.getInt(index)) + case LongType => consumer.addLong(record.getLong(index)) + case DoubleType => consumer.addDouble(record.getDouble(index)) + case FloatType => consumer.addFloat(record.getFloat(index)) + case BooleanType => consumer.addBoolean(record.getBoolean(index)) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } + + def getSchema(schemaString : String) : MessageType = + MessageTypeParser.parseMessageType(schemaString) + + def convertToAttributes(parquetSchema: MessageType) : Seq[Attribute] = { + parquetSchema.getColumns.map { + case (desc) => { + val ctype = toDataType(desc.getType) + val name: String = desc.getPath.mkString(".") + new AttributeReference(name, ctype, false)() + } + } + } + + // TODO: allow nesting? 
+ def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val fields: Seq[ParquetType] = attributes.map { + a => new ParquetPrimitiveType(Repetition.OPTIONAL, fromDataType(a.dataType), a.name) + } + new MessageType("root", fields) + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration) { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put("path", path.toString) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = + ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. If so, this is used. Else we read the actual footer at the given + * location. + * @param path The path at which we expect one (or more) Parquet files. + * @return The `ParquetMetadata` containing among other things the schema. 
+ */ + def readMetaData(origPath: Path): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + // TODO: since this is called from ParquetRelation (LogicalPlan) we don't have access + // to SparkContext's hadoopConfig; in principle the default FileSystem may be different(?!) + val conf = ContextUtil.getConfiguration(job) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { + // TODO: improve exception handling, etc. + ParquetFileReader.readFooter(conf, metadataPath) + } else { + if (!fs.exists(path) || !fs.isFile(path)) { + throw new FileNotFoundException( + s"Could not find file ${path.toString} when trying to read metadata") + } + ParquetFileReader.readFooter(conf, path) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala new file mode 100644 index 000000000..61121103c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import parquet.io.InvalidRecordException +import parquet.schema.MessageType +import parquet.hadoop.{ParquetOutputFormat, ParquetInputFormat} +import parquet.hadoop.util.ContextUtil + +import org.apache.spark.rdd.RDD +import org.apache.spark.{TaskContext, SerializableWritable, SparkContext} +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute, Expression} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, LeafNode} + +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import java.io.IOException +import java.text.SimpleDateFormat +import java.util.Date + +/** + * Parquet table scan operator. Imports the file that backs the given + * [[ParquetRelation]] as a RDD[Row]. 
+ */ +case class ParquetTableScan( + @transient output: Seq[Attribute], + @transient relation: ParquetRelation, + @transient columnPruningPred: Option[Expression])( + @transient val sc: SparkContext) + extends LeafNode { + + override def execute(): RDD[Row] = { + val job = new Job(sc.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass( + job, + classOf[org.apache.spark.sql.parquet.RowReadSupport]) + val conf: Configuration = ContextUtil.getConfiguration(job) + conf.set( + RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, + ParquetTypesConverter.convertFromAttributes(output).toString) + // TODO: think about adding record filters + /* Comments regarding record filters: it would be nice to push down as much filtering + to Parquet as possible. However, currently it seems we cannot pass enough information + to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's + ``FilteredRecordReader`` (via Configuration, for example). Simple + filter-rows-by-column-values however should be supported. + */ + sc.newAPIHadoopFile( + relation.path, + classOf[ParquetInputFormat[Row]], + classOf[Void], classOf[Row], + conf) + .map(_._2) + } + + /** + * Applies a (candidate) projection. + * + * @param prunedAttributes The list of attributes to be used in the projection. + * @return Pruned TableScan. + */ + def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { + val success = validateProjection(prunedAttributes) + if (success) { + ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc) + } else { + sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") + this + } + } + + /** + * Evaluates a candidate projection by checking whether the candidate is a subtype + * of the original type. + * + * @param projection The candidate projection. + * @return True if the projection is valid, false otherwise. 
 + */ + private def validateProjection(projection: Seq[Attribute]): Boolean = { + val original: MessageType = relation.parquetSchema + val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) + try { + original.checkContains(candidate) + true + } catch { + case e: InvalidRecordException => { + false + } + } + } +} + +case class InsertIntoParquetTable( + @transient relation: ParquetRelation, + @transient child: SparkPlan)( + @transient val sc: SparkContext) + extends UnaryNode with SparkHadoopMapReduceUtil { + + /** + * Inserts all the rows in the Parquet file. Note that OVERWRITE is implicit, since + * Parquet files are write-once. + */ + override def execute() = { + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // an incompatible schema the execution will fail. It would be nice + // to catch this early on, maybe having the planner validate the schema + // before calling execute(). + + val childRdd = child.execute() + assert(childRdd != null) + + val job = new Job(sc.hadoopConfiguration) + + ParquetOutputFormat.setWriteSupportClass( + job, + classOf[org.apache.spark.sql.parquet.RowWriteSupport]) + + // TODO: move that to function in object + val conf = job.getConfiguration + conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) + + val fspath = new Path(relation.path) + val fs = fspath.getFileSystem(conf) + + try { + fs.delete(fspath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${fspath.toString} prior" + + s" to InsertIntoParquetTable:\n${e.toString}") + } + saveAsHadoopFile(childRdd, relation.path.toString, conf) + + // We return the child RDD to allow chaining (alternatively, one could return nothing). 
+ childRdd + } + + override def output = child.output + + // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]] + // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2? + // .. then we could use the default one and could use [[MutablePair]] + // instead of ``Tuple2`` + private def saveAsHadoopFile( + rdd: RDD[Row], + path: String, + conf: Configuration) { + val job = new Job(conf) + val keyType = classOf[Void] + val outputFormatType = classOf[parquet.hadoop.ParquetOutputFormat[Row]] + job.setOutputKeyClass(keyType) + job.setOutputValueClass(classOf[Row]) + val wrappedConf = new SerializableWritable(job.getConfiguration) + NewFileOutputFormat.setOutputPath(job, new Path(path)) + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + val stageId = sc.newRddId() + + def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + /* "reduce task" */ + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + attemptNumber) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = outputFormatType.newInstance + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + while (iter.hasNext) { + val row = iter.next() + writer.write(null, row) + } + writer.close(hadoopContext) + committer.commitTask(hadoopContext) + return 1 + } + val jobFormat = outputFormatType.newInstance + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. 
+ */ + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + jobCommitter.setupJob(jobTaskContext) + sc.runJob(rdd, writeShard _) + jobCommitter.commitJob(jobTaskContext) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala new file mode 100644 index 000000000..c2ae18b88 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.Logging + +import parquet.io.api._ +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.hadoop.api.{WriteSupport, ReadSupport} +import parquet.hadoop.api.ReadSupport.ReadContext +import parquet.hadoop.ParquetOutputFormat +import parquet.column.ParquetProperties + +import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} +import org.apache.spark.sql.catalyst.types._ + +/** + * A `parquet.io.api.RecordMaterializer` for Rows. + * + *@param root The root group converter for the record. + */ +class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterializer[Row] { + + def this(parquetSchema: MessageType) = + this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) + + override def getCurrentRecord: Row = root.getCurrentRecord + + override def getRootConverter: GroupConverter = root +} + +/** + * A `parquet.hadoop.api.ReadSupport` for Row objects. 
 + */ +class RowReadSupport extends ReadSupport[Row] with Logging { + + override def prepareForRead( + conf: Configuration, + stringMap: java.util.Map[String, String], + fileSchema: MessageType, + readContext: ReadContext): RecordMaterializer[Row] = { + log.debug(s"preparing for read with schema ${fileSchema.toString}") + new RowRecordMaterializer(readContext.getRequestedSchema) + } + + override def init( + configuration: Configuration, + keyValueMetaData: java.util.Map[String, String], + fileSchema: MessageType): ReadContext = { + val requested_schema_string = + configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) + val requested_schema = + MessageTypeParser.parseMessageType(requested_schema_string) + + log.debug(s"read support initialized for original schema ${requested_schema.toString}") + new ReadContext(requested_schema, keyValueMetaData) + } +} + +object RowReadSupport { + val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" +} + +/** + * A `parquet.hadoop.api.WriteSupport` for Row objects. + */ +class RowWriteSupport extends WriteSupport[Row] with Logging { + def setSchema(schema: MessageType, configuration: Configuration) { + // for testing + this.schema = schema + // TODO: could use Attributes themselves instead of Parquet schema? 
+ configuration.set( + RowWriteSupport.PARQUET_ROW_SCHEMA, + schema.toString) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } + + def getSchema(configuration: Configuration): MessageType = { + return MessageTypeParser.parseMessageType( + configuration.get(RowWriteSupport.PARQUET_ROW_SCHEMA)) + } + + private var schema: MessageType = null + private var writer: RecordConsumer = null + private var attributes: Seq[Attribute] = null + + override def init(configuration: Configuration): WriteSupport.WriteContext = { + schema = if (schema == null) getSchema(configuration) else schema + attributes = ParquetTypesConverter.convertToAttributes(schema) + new WriteSupport.WriteContext( + schema, + new java.util.HashMap[java.lang.String, java.lang.String]()); + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + writer = recordConsumer + } + + // TODO: add groups (nested fields) + override def write(record: Row): Unit = { + var index = 0 + writer.startMessage() + while(index < attributes.size) { + // null values indicate optional fields but we do not check currently + if (record(index) != null && record(index) != Nil) { + writer.startField(attributes(index).name, index) + ParquetTypesConverter.consumeType(writer, attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } +} + +object RowWriteSupport { + val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" +} + +/** + * A `parquet.io.api.GroupConverter` that is able to convert a Parquet record + * to a [[org.apache.spark.sql.catalyst.expressions.Row]] object. + * + * @param schema The corresponding Catalyst schema in the form of a list of attributes. 
+ */ +class CatalystGroupConverter( + schema: Seq[Attribute], + protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { + + def this(schema: Seq[Attribute]) = this(schema, new ParquetRelation.RowType(schema.length)) + + val converters: Array[Converter] = schema.map { + a => a.dataType match { + case ctype: NativeType => + // note: for some reason matching for StringType fails so use this ugly if instead + if (ctype == StringType) new CatalystPrimitiveStringConverter(this, schema.indexOf(a)) + else new CatalystPrimitiveConverter(this, schema.indexOf(a)) + case _ => throw new RuntimeException( + s"unable to convert datatype ${a.dataType.toString} in CatalystGroupConverter") + } + }.toArray + + override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) + + private[parquet] def getCurrentRecord: ParquetRelation.RowType = current + + override def start(): Unit = { + var i = 0 + while (i < schema.length) { + current.setNullAt(i) + i = i + 1 + } + } + + override def end(): Unit = {} +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet types to Catalyst types. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. 
+ */ +class CatalystPrimitiveConverter( + parent: CatalystGroupConverter, + fieldIndex: Int) extends PrimitiveConverter { + // TODO: consider refactoring these together with ParquetTypesConverter + override def addBinary(value: Binary): Unit = + // TODO: fix this once a setBinary will become available in MutableRow + parent.getCurrentRecord.setByte(fieldIndex, value.getBytes.apply(0)) + + override def addBoolean(value: Boolean): Unit = + parent.getCurrentRecord.setBoolean(fieldIndex, value) + + override def addDouble(value: Double): Unit = + parent.getCurrentRecord.setDouble(fieldIndex, value) + + override def addFloat(value: Float): Unit = + parent.getCurrentRecord.setFloat(fieldIndex, value) + + override def addInt(value: Int): Unit = + parent.getCurrentRecord.setInt(fieldIndex, value) + + override def addLong(value: Long): Unit = + parent.getCurrentRecord.setLong(fieldIndex, value) +} + +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet strings (fixed-length byte arrays) + * into Catalyst Strings. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +class CatalystPrimitiveStringConverter( + parent: CatalystGroupConverter, + fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala new file mode 100644 index 000000000..bbe409fb9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.Job + +import parquet.schema.{MessageTypeParser, MessageType} +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.ParquetWriter + +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.expressions.GenericRow +import java.nio.charset.Charset + +object ParquetTestData { + + val testSchema = + """message myrecord { + |optional boolean myboolean; + |optional int32 myint; + |optional binary mystring; + |optional int64 mylong; + |optional float myfloat; + |optional double mydouble; + |}""".stripMargin + + // field names for test assertion error messages + val testSchemaFieldNames = Seq( + "myboolean:Boolean", + "mtint:Int", + "mystring:String", + "mylong:Long", + "myfloat:Float", + "mydouble:Double" + ) + + val subTestSchema = + """ + |message myrecord { + |optional boolean myboolean; + |optional int64 mylong; + |} + """.stripMargin + + // field names for test assertion error messages + val subTestSchemaFieldNames = Seq( + "myboolean:Boolean", + "mylong:Long" + ) + + val testFile = getTempFilePath("testParquetFile").getCanonicalFile + + lazy val testData = new ParquetRelation("testData", testFile.toURI.toString) + + def writeFile = { + testFile.delete + val path: Path = new 
Path(testFile.toURI) + val job = new Job() + val configuration: Configuration = ContextUtil.getConfiguration(job) + val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) + + val writeSupport = new RowWriteSupport() + writeSupport.setSchema(schema, configuration) + val writer = new ParquetWriter(path, writeSupport) + for(i <- 0 until 15) { + val data = new Array[Any](6) + if (i % 3 == 0) { + data.update(0, true) + } else { + data.update(0, false) + } + if (i % 5 == 0) { + data.update(1, 5) + } else { + data.update(1, null) // optional + } + data.update(2, "abc") + data.update(3, i.toLong << 33) + data.update(4, 2.5F) + data.update(5, 4.5D) + writer.write(new GenericRow(data.toArray)) + } + writer.close() + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala new file mode 100644 index 000000000..ca56c4476 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark +package sql +package test + +/** A SQLContext that can be used for local testing. 
*/ +object TestSQLContext + extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties new file mode 100644 index 000000000..7bb6789bd --- /dev/null +++ b/sql/core/src/test/resources/log4j.properties @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=DEBUG, CA, FA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold = WARN + + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/unit-tests.log +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n + +# Set the logger level of File Appender to WARN +log4j.appender.FA.Threshold = INFO + +# Some packages are noisy for no good reason. 
+log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF + +# Parquet logging +parquet.hadoop.InternalParquetRecordReader=WARN +log4j.logger.parquet.hadoop.InternalParquetRecordReader=WARN +parquet.hadoop.ParquetInputFormat=WARN +log4j.logger.parquet.hadoop.ParquetInputFormat=WARN diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala new file mode 100644 index 000000000..37c90a18a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class DslQuerySuite extends QueryTest { + import TestData._ + + test("table scan") { + checkAnswer( + testData, + testData.collect().toSeq) + } + + test("agg") { + checkAnswer( + testData2.groupBy('a)('a, Sum('b)), + Seq((1,3),(2,3),(3,3)) + ) + } + + test("select *") { + checkAnswer( + testData.select(Star(None)), + testData.collect().toSeq) + } + + test("simple select") { + checkAnswer( + testData.where('key === 1).select('value), + Seq(Seq("1"))) + } + + test("sorting") { + checkAnswer( + testData2.orderBy('a.asc, 'b.asc), + Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + + checkAnswer( + testData2.orderBy('a.asc, 'b.desc), + Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.desc), + Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + + checkAnswer( + testData2.orderBy('a.desc, 'b.asc), + Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + } + + test("average") { + checkAnswer( + testData2.groupBy()(Average('a)), + 2.0) + } + + test("count") { + checkAnswer( + testData2.groupBy()(Count(1)), + testData2.count() + ) + } + + test("null count") { + checkAnswer( + testData3.groupBy('a)('a, Count('b)), + Seq((1,0), (2, 1)) + ) + + checkAnswer( + testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + (2, 1, 2, 2, 1) :: Nil + ) + } + + test("inner join where, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, 
"C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join ON, one match per row") { + checkAnswer( + upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join, where, multiple matches") { + val x = testData2.where('a === 1).subquery('x) + val y = testData2.where('a === 1).subquery('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + (1,1,1,1) :: + (1,1,1,2) :: + (1,2,1,1) :: + (1,2,1,2) :: Nil + ) + } + + test("inner join, no matches") { + val x = testData2.where('a === 1).subquery('x) + val y = testData2.where('a === 2).subquery('y) + checkAnswer( + x.join(y).where("x.a".attr === "y.a".attr), + Nil) + } + + test("big inner join, 4 matches per row") { + val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) + val bigDataX = bigData.subquery('x) + val bigDataY = bigData.subquery('y) + + checkAnswer( + bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), + testData.flatMap( + row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + } + + test("cartisian product join") { + checkAnswer( + testData3.join(testData3), + (1, null, 1, null) :: + (1, null, 2, 2) :: + (2, 2, 1, null) :: + (2, 2, 2, 2) :: Nil) + } + + test("left outer join") { + checkAnswer( + upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + } + + test("right outer join") { + checkAnswer( + lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } + + test("full outer join") { + val left = upperCaseData.where('N <= 4).subquery('left) + val right = upperCaseData.where('N >= 3).subquery('right) + + checkAnswer( + left.join(right, 
FullOuter, Some("left.N".attr === "right.N".attr)), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala new file mode 100644 index 000000000..83908edf5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlannerSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package execution + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext.planner._ + +class PlannerSuite extends FunSuite { + + + test("unions are collapsed") { + val query = testData.unionAll(testData).unionAll(testData).logicalPlan + val planned = BasicOperators(query).head + val logicalUnions = query collect { case u: logical.Union => u} + val physicalUnions = planned collect { case u: execution.Union => u} + + assert(logicalUnions.size === 2) + assert(physicalUnions.size === 1) + } + + test("count is partially aggregated") { + val query = testData.groupBy('value)(Count('key)).analyze.logicalPlan + val planned = PartialAggregation(query).head + val aggregations = planned.collect { case a: Aggregate => a } + + assert(aggregations.size === 2) + } + + test("count distinct is not partially aggregated") { + val query = testData.groupBy('value)(CountDistinct('key :: Nil)).analyze.logicalPlan + val planned = PartialAggregation(query.logicalPlan) + assert(planned.isEmpty) + } + + test("mixed aggregates are not partially aggregated") { + val query = + testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).analyze.logicalPlan + val planned = PartialAggregation(query) + assert(planned.isEmpty) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala new file mode 100644 index 000000000..728fecede --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class QueryTest extends FunSuite { + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param plan the query to be executed + * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. 
+ */ + protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Any): Unit = { + val convertedAnswer = expectedAnswer match { + case s: Seq[_] if s.isEmpty => s + case s: Seq[_] if s.head.isInstanceOf[Product] && + !s.head.isInstanceOf[Seq[_]] => s.map(_.asInstanceOf[Product].productIterator.toIndexedSeq) + case s: Seq[_] => s + case singleItem => Seq(Seq(singleItem)) + } + + val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty + def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer + val sparkAnswer = try rdd.collect().toSeq catch { + case e: Exception => + fail( + s""" + |Exception thrown while executing query: + |${rdd.logicalPlan} + |== Exception == + |$e + """.stripMargin) + } + if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) { + fail(s""" + |Results do not match for query: + |${rdd.logicalPlan} + |== Analyzed Plan == + |${rdd.queryExecution.analyzed} + |== RDD == + |$rdd + |== Results == + |${sideBySide( + prepareAnswer(convertedAnswer).map(_.toString), + prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")} + """.stripMargin) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala new file mode 100644 index 000000000..5728313d6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ +import TestData._ + +class SQLQuerySuite extends QueryTest { + test("agg") { + checkAnswer( + sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), + Seq((1,3),(2,3),(3,3)) + ) + } + + test("select *") { + checkAnswer( + sql("SELECT * FROM testData"), + testData.collect().toSeq) + } + + test("simple select") { + checkAnswer( + sql("SELECT value FROM testData WHERE key = 1"), + Seq(Seq("1"))) + } + + test("sorting") { + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), + Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2))) + + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), + Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1))) + + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), + Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1))) + + checkAnswer( + sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), + Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2))) + } + + test("average") { + checkAnswer( + sql("SELECT AVG(a) FROM testData2"), + 2.0) + } + + test("count") { + checkAnswer( + sql("SELECT COUNT(*) FROM testData2"), + testData2.count() + ) + } + + // No support for primitive nulls yet. 
+ ignore("null count") { + checkAnswer( + sql("SELECT a, COUNT(b) FROM testData3"), + Seq((1,0), (2, 1)) + ) + + checkAnswer( + testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)), + (2, 1, 2, 2, 1) :: Nil + ) + } + + test("inner join where, one match per row") { + checkAnswer( + sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join ON, one match per row") { + checkAnswer( + sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON n = N"), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + )) + } + + test("inner join, where, multiple matches") { + checkAnswer( + sql(""" + |SELECT * FROM + | (SELECT * FROM testData2 WHERE a = 1) x JOIN + | (SELECT * FROM testData2 WHERE a = 1) y + |WHERE x.a = y.a""".stripMargin), + (1,1,1,1) :: + (1,1,1,2) :: + (1,2,1,1) :: + (1,2,1,2) :: Nil + ) + } + + test("inner join, no matches") { + checkAnswer( + sql( + """ + |SELECT * FROM + | (SELECT * FROM testData2 WHERE a = 1) x JOIN + | (SELECT * FROM testData2 WHERE a = 2) y + |WHERE x.a = y.a""".stripMargin), + Nil) + } + + test("big inner join, 4 matches per row") { + + + checkAnswer( + sql( + """ + |SELECT * FROM + | (SELECT * FROM testData UNION ALL + | SELECT * FROM testData UNION ALL + | SELECT * FROM testData UNION ALL + | SELECT * FROM testData) x JOIN + | (SELECT * FROM testData UNION ALL + | SELECT * FROM testData UNION ALL + | SELECT * FROM testData UNION ALL + | SELECT * FROM testData) y + |WHERE x.key = y.key""".stripMargin), + testData.flatMap( + row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq) + } + + ignore("cartisian product join") { + checkAnswer( + testData3.join(testData3), + (1, null, 1, null) :: + (1, null, 2, 2) :: + (2, 2, 1, null) :: + (2, 2, 2, 2) :: Nil) + } + + test("left outer join") { + checkAnswer( + sql("SELECT * FROM upperCaseData 
LEFT OUTER JOIN lowerCaseData ON n = N"), + (1, "A", 1, "a") :: + (2, "B", 2, "b") :: + (3, "C", 3, "c") :: + (4, "D", 4, "d") :: + (5, "E", null, null) :: + (6, "F", null, null) :: Nil) + } + + test("right outer join") { + checkAnswer( + sql("SELECT * FROM lowerCaseData RIGHT OUTER JOIN upperCaseData ON n = N"), + (1, "a", 1, "A") :: + (2, "b", 2, "B") :: + (3, "c", 3, "C") :: + (4, "d", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } + + test("full outer join") { + checkAnswer( + sql( + """ + |SELECT * FROM + | (SELECT * FROM upperCaseData WHERE N <= 4) left FULL OUTER JOIN + | (SELECT * FROM upperCaseData WHERE N >= 3) right + | ON left.N = right.N + """.stripMargin), + (1, "A", null, null) :: + (2, "B", null, null) :: + (3, "C", 3, "C") :: + (4, "D", 4, "D") :: + (null, null, 5, "E") :: + (null, null, 6, "F") :: Nil) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala new file mode 100644 index 000000000..640292571 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +object TestData { + case class TestData(key: Int, value: String) + val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString))) + testData.registerAsTable("testData") + + case class TestData2(a: Int, b: Int) + val testData2: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + TestData2(1, 1) :: + TestData2(1, 2) :: + TestData2(2, 1) :: + TestData2(2, 2) :: + TestData2(3, 1) :: + TestData2(3, 2) :: Nil + ) + testData2.registerAsTable("testData2") + + // TODO: There is no way to express null primitives as case classes currently... + val testData3 = + logical.LocalRelation('a.int, 'b.int).loadData( + (1, null) :: + (2, 2) :: Nil + ) + + case class UpperCaseData(N: Int, L: String) + val upperCaseData = + TestSQLContext.sparkContext.parallelize( + UpperCaseData(1, "A") :: + UpperCaseData(2, "B") :: + UpperCaseData(3, "C") :: + UpperCaseData(4, "D") :: + UpperCaseData(5, "E") :: + UpperCaseData(6, "F") :: Nil + ) + upperCaseData.registerAsTable("upperCaseData") + + case class LowerCaseData(n: Int, l: String) + val lowerCaseData = + TestSQLContext.sparkContext.parallelize( + LowerCaseData(1, "a") :: + LowerCaseData(2, "b") :: + LowerCaseData(3, "c") :: + LowerCaseData(4, "d") :: Nil + ) + lowerCaseData.registerAsTable("lowerCaseData") +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala new file mode 100644 index 000000000..08265b7a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package execution + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.test._ + + +import TestSQLContext._ + +/** + * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns + * from the input data. These will be replaced during analysis with specific AttributeReferences + * and then bound to specific ordinals during query planning. While TGFs could also access specific + * columns using hand-coded ordinals, doing so violates data independence. + * + * Note: this is only a rough example of how TGFs can be expressed, the final version will likely + * involve a lot more sugar for cleaner use in Scala/Java/etc. 
+ */ +case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { + def children = input + protected def makeOutput() = 'nameAndAge.string :: Nil + + val Seq(nameAttr, ageAttr) = input + + override def apply(input: Row): TraversableOnce[Row] = { + val name = nameAttr.apply(input) + val age = ageAttr.apply(input).asInstanceOf[Int] + + Iterator( + new GenericRow(Array[Any](s"$name is $age years old")), + new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) + } +} + +class TgfSuite extends QueryTest { + val inputData = + logical.LocalRelation('name.string, 'age.int).loadData( + ("michael", 29) :: Nil + ) + + test("simple tgf example") { + checkAnswer( + inputData.generate(ExampleTGF()), + Seq( + "michael is 29 years old" :: Nil, + "Next year, michael will be 30 years old" :: Nil)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala new file mode 100644 index 000000000..8b2ccb52d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.test.TestSQLContext + +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.fs.{Path, FileSystem} + +import parquet.schema.MessageTypeParser +import parquet.hadoop.ParquetFileWriter +import parquet.hadoop.util.ContextUtil + +class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { + override def beforeAll() { + ParquetTestData.writeFile + } + + override def afterAll() { + ParquetTestData.testFile.delete() + } + + test("Import of simple Parquet file") { + val result = getRDD(ParquetTestData.testData).collect() + assert(result.size === 15) + result.zipWithIndex.foreach { + case (row, index) => { + val checkBoolean = + if (index % 3 == 0) + row(0) == true + else + row(0) == false + assert(checkBoolean === true, s"boolean field value in line $index did not match") + if (index % 5 == 0) assert(row(1) === 5, s"int field value in line $index did not match") + assert(row(2) === "abc", s"string field value in line $index did not match") + assert(row(3) === (index.toLong << 33), s"long value in line $index did not match") + assert(row(4) === 2.5F, s"float field value in line $index did not match") + assert(row(5) === 4.5D, s"double field value in line $index did not match") + } + } + } + + test("Projection of simple Parquet file") { + val scanner = new ParquetTableScan( + ParquetTestData.testData.output, + ParquetTestData.testData, + None)(TestSQLContext.sparkContext) + val projected = scanner.pruneColumns(ParquetTypesConverter + .convertToAttributes(MessageTypeParser + .parseMessageType(ParquetTestData.subTestSchema))) + assert(projected.output.size === 2) + val result = projected + .execute() + .map(_.copy()) + .collect() + result.zipWithIndex.foreach { + case (row, index) => { + 
if (index % 3 == 0) + assert(row(0) === true, s"boolean field value in line $index did not match (every third row)") + else + assert(row(0) === false, s"boolean field value in line $index did not match") + assert(row(1) === (index.toLong << 33), s"long field value in line $index did not match") + assert(row.size === 2, s"number of columns in projection in line $index is incorrect") + } + } + } + + test("Writing metadata from scratch for table CREATE") { + val job = new Job() + val path = new Path(getTempFilePath("testtable").getCanonicalFile.toURI.toString) + val fs: FileSystem = FileSystem.getLocal(ContextUtil.getConfiguration(job)) + ParquetTypesConverter.writeMetaData( + ParquetTestData.testData.output, + path, + TestSQLContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + val metaData = ParquetTypesConverter.readMetaData(path) + assert(metaData != null) + ParquetTestData + .testData + .parquetSchema + .checkContains(metaData.getFileMetaData.getSchema) // throws exception if incompatible + metaData + .getFileMetaData + .getSchema + .checkContains(ParquetTestData.testData.parquetSchema) // throws exception if incompatible + fs.delete(path, true) + } + + /** + * Computes the given [[ParquetRelation]] and returns its RDD. + * + * @param parquetRelation The Parquet relation. + * @return An RDD of Rows. 
+ */ + private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = { + val scanner = new ParquetTableScan( + parquetRelation.output, + parquetRelation, + None)(TestSQLContext.sparkContext) + scanner + .execute + .map(_.copy()) + } +} + diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml new file mode 100644 index 000000000..7b5ea98f2 --- /dev/null +++ b/sql/hive/pom.xml @@ -0,0 +1,81 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-hive_2.10 + jar + Spark Project Hive + http://spark.apache.org/ + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + + + org.apache.hive + hive-metastore + ${hive.version} + + + org.apache.hive + hive-exec + ${hive.version} + + + org.apache.hive + hive-serde + ${hive.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala new file mode 100644 index 000000000..08d390e88 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/hadoop/mapred/SparkHadoopWriter.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.mapred + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Writable + +import org.apache.spark.Logging +import org.apache.spark.SerializableWritable + +import org.apache.hadoop.hive.ql.exec.{Utilities, FileSinkOperator} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. 
+ */ +protected[apache] +class SparkHiveHadoopWriter( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + private val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private var format: HiveOutputFormat[AnyRef, Writable] = null + @transient private var committer: OutputCommitter = null + @transient private var jobContext: JobContext = null + @transient private var taskContext: TaskAttemptContext = null + + def preSetup() { + setIDs(0, 0, 0) + setConfParams() + + val jCtxt = getJobContext() + getOutputCommitter().setupJob(jCtxt) + } + + + def setup(jobid: Int, splitid: Int, attemptid: Int) { + setIDs(jobid, splitid, attemptid) + setConfParams() + } + + def open() { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + + val extension = Utilities.getFileExtension( + conf.value, + fileSinkConf.getCompressed, + getOutputFormat()) + + val outputName = "part-" + numfmt.format(splitID) + extension + val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) + + getOutputCommitter().setupTask(getTaskContext()) + writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + path, + null) + } + + def write(value: Writable) { + if (writer != null) { + writer.write(value) + } else { + throw new IOException("Writer is null, open() has not been called") + } + } + + def close() { + // Seems the boolean value passed into close does not matter. 
+ writer.close(false) + } + + def commit() { + val taCtxt = getTaskContext() + val cmtr = getOutputCommitter() + if (cmtr.needsTaskCommit(taCtxt)) { + try { + cmtr.commitTask(taCtxt) + logInfo (taID + ": Committed") + } catch { + case e: IOException => { + logError("Error committing the output of task: " + taID.value, e) + cmtr.abortTask(taCtxt) + throw e + } + } + } else { + logWarning ("No need to commit output of task: " + taID.value) + } + } + + def commitJob() { + // always ? Or if cmtr.needsTaskCommit ? + val cmtr = getOutputCommitter() + cmtr.commitJob(getJobContext()) + } + + // ********* Private Functions ********* + + private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { + if (format == null) { + format = conf.value.getOutputFormat() + .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + } + format + } + + private def getOutputCommitter(): OutputCommitter = { + if (committer == null) { + committer = conf.value.getOutputCommitter + } + committer + } + + private def getJobContext(): JobContext = { + if (jobContext == null) { + jobContext = newJobContext(conf.value, jID.value) + } + jobContext + } + + private def getTaskContext(): TaskAttemptContext = { + if (taskContext == null) { + taskContext = newTaskAttemptContext(conf.value, taID.value) + } + taskContext + } + + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { + jobID = jobid + splitID = splitid + attemptID = attemptid + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +object SparkHiveHadoopWriter { + 
def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala new file mode 100644 index 000000000..4aad876cc --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package hive + +import java.io.{PrintStream, InputStreamReader, BufferedReader, File} +import java.util.{ArrayList => JArrayList} +import scala.language.implicitConversions + +import org.apache.spark.SparkContext +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.processors.{CommandProcessorResponse, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.CommandProcessor +import org.apache.hadoop.hive.ql.Driver +import org.apache.spark.rdd.RDD + +import catalyst.analysis.{Analyzer, OverrideCatalog} +import catalyst.expressions.GenericRow +import catalyst.plans.logical.{BaseRelation, LogicalPlan, NativeCommand, ExplainCommand} +import catalyst.types._ + +import org.apache.spark.sql.execution._ + +import scala.collection.JavaConversions._ + +/** + * Starts up an instance of hive where metadata is stored locally. An in-process metadata database is + * created with data stored in ./metastore. Warehouse data is stored in ./warehouse. + */ +class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { + + lazy val metastorePath = new File("metastore").getCanonicalPath + lazy val warehousePath: String = new File("warehouse").getCanonicalPath + + /** Sets up the system initially or after a RESET command */ + protected def configure() { + // TODO: refactor this so we can work with other databases. + runSqlHive( + s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true") + runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath) + } + + configure() // Must be called before initializing the catalog below. +} + +/** + * An instance of the Spark SQL execution engine that integrates with data stored in Hive. + * Configuration for Hive is read from hive-site.xml on the classpath. 
+ */ +class HiveContext(sc: SparkContext) extends SQLContext(sc) { + self => + + override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql) + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. + @transient + protected val outputBuffer = new java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](10240) + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.size + } + + override def toString = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while(line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } + } + + @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) + @transient protected[hive] lazy val sessionState = new SessionState(hiveconf) + + sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") + sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + + /* A catalyst metadata catalog that points to the Hive Metastore. */ + @transient + override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog + + /* An analyzer that uses the Hive metastore. */ + @transient + override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + + def tables: Seq[BaseRelation] = { + // TODO: Move this functionality to Catalog. Make client protected. 
+ val allTables = catalog.client.getAllTables("default") + allTables.map(catalog.lookupRelation(None, _, None)).collect { case b: BaseRelation => b } + } + + /** + * Runs the specified SQL query using Hive. + */ + protected def runSqlHive(sql: String): Seq[String] = { + val maxResults = 100000 + val results = runHive(sql, 100000) + // It is very confusing when you only get back some of the results... + if (results.size == maxResults) sys.error("RESULTS POSSIBLY TRUNCATED") + results + } + + // TODO: Move this. + + SessionState.start(sessionState) + + /** + * Execute the command using Hive and return the results as a sequence. Each element + * in the sequence is one row. + */ + protected def runHive(cmd: String, maxRows: Int = 1000): Seq[String] = { + try { + val cmd_trimmed: String = cmd.trim() + val tokens: Array[String] = cmd_trimmed.split("\\s+") + val cmd_1: String = cmd_trimmed.substring(tokens(0).length()).trim() + val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hiveconf) + + SessionState.start(sessionState) + + if (proc.isInstanceOf[Driver]) { + val driver: Driver = proc.asInstanceOf[Driver] + driver.init() + + val results = new JArrayList[String] + val response: CommandProcessorResponse = driver.run(cmd) + // Throw an exception if there is an error in query processing. 
+ if (response.getResponseCode != 0) { + driver.destroy() + throw new QueryExecutionException(response.getErrorMessage) + } + driver.setMaxRows(maxRows) + driver.getResults(results) + driver.destroy() + results + } else { + sessionState.out.println(tokens(0) + " " + cmd_1) + Seq(proc.run(cmd_1).getResponseCode.toString) + } + } catch { + case e: Exception => + logger.error( + s""" + |====================== + |HIVE FAILURE OUTPUT + |====================== + |${outputBuffer.toString} + |====================== + |END HIVE FAILURE OUTPUT + |====================== + """.stripMargin) + throw e + } + } + + @transient + val hivePlanner = new SparkPlanner with HiveStrategies { + val hiveContext = self + + override val strategies: Seq[Strategy] = Seq( + TopK, + ColumnPrunings, + PartitionPrunings, + HiveTableScans, + DataSinks, + Scripts, + PartialAggregation, + SparkEquiInnerJoin, + BasicOperators, + CartesianProduct, + BroadcastNestedLoopJoin + ) + } + + @transient + override val planner = hivePlanner + + @transient + protected lazy val emptyResult = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + + /** Extends QueryExecution with hive specific features. */ + abstract class QueryExecution extends super.QueryExecution { + // TODO: Create mixin for the analyzer instead of overriding things here. + override lazy val optimizedPlan = + optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + + // TODO: We are losing schema here. 
+ override lazy val toRdd: RDD[Row] = + analyzed match { + case NativeCommand(cmd) => + val output = runSqlHive(cmd) + + if (output.size == 0) { + emptyResult + } else { + val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) + sparkContext.parallelize(asRows, 1) + } + case _ => + executedPlan.execute.map(_.copy()) + } + + protected val primitiveTypes = + Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, + ShortType, DecimalType) + + protected def toHiveString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ))=> + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ + protected def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ))=> + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** + * Returns the result as a hive compatible sequence of strings. For native commands, the + * execution is simply passed back to Hive. + */ + def stringResult(): Seq[String] = analyzed match { + case NativeCommand(cmd) => runSqlHive(cmd) + case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n") + case query => + val result: Seq[Seq[Any]] = toRdd.collect().toSeq + // We need the types so we can output struct field names + val types = analyzed.output.map(_.dataType) + // Reformat to match hive tab delimited output. + val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq + asString + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala new file mode 100644 index 000000000..e4d50722c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} +import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.Deserializer + +import catalyst.analysis.Catalog +import catalyst.expressions._ +import catalyst.plans.logical +import catalyst.plans.logical._ +import catalyst.rules._ +import catalyst.types._ + +import scala.collection.JavaConversions._ + +class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { + import HiveMetastoreTypes._ + + val client = Hive.get(hive.hiveconf) + + def lookupRelation( + db: Option[String], + tableName: String, + alias: Option[String]): LogicalPlan = { + val databaseName = db.getOrElse(hive.sessionState.getCurrentDatabase()) + val table = client.getTable(databaseName, tableName) + val partitions: Seq[Partition] = + if (table.isPartitioned) { + client.getPartitions(table) + } else { + Nil + } + + // Since HiveQL is case insensitive for table names we make them all lowercase. 
+ MetastoreRelation( + databaseName.toLowerCase, + tableName.toLowerCase, + alias)(table.getTTable, partitions.map(part => part.getTPartition)) + } + + def createTable(databaseName: String, tableName: String, schema: Seq[Attribute]) { + val table = new Table(databaseName, tableName) + val hiveSchema = + schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), "")) + table.setFields(hiveSchema) + + val sd = new StorageDescriptor() + table.getTTable.setSd(sd) + sd.setCols(hiveSchema) + + // TODO: THESE ARE ALL DEFAULTS, WE NEED TO PARSE / UNDERSTAND the output specs. + sd.setCompressed(false) + sd.setParameters(Map[String, String]()) + sd.setInputFormat("org.apache.hadoop.mapred.TextInputFormat") + sd.setOutputFormat("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat") + val serDeInfo = new SerDeInfo() + serDeInfo.setName(tableName) + serDeInfo.setSerializationLib("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + serDeInfo.setParameters(Map[String, String]()) + sd.setSerdeInfo(serDeInfo) + client.createTable(table) + } + + /** + * Creates any tables required for query execution. + * For example, because of a CREATE TABLE X AS statement. + */ + object CreateTables extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case InsertIntoCreatedTable(db, tableName, child) => + val databaseName = db.getOrElse(SessionState.get.getCurrentDatabase()) + + createTable(databaseName, tableName, child.output) + + InsertIntoTable( + lookupRelation(Some(databaseName), tableName, None).asInstanceOf[BaseRelation], + Map.empty, + child, + overwrite = false) + } + } + + /** + * Casts input data to correct data types according to table definition before inserting into + * that table. 
+   */
+  object PreInsertionCasts extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+      // Wait until children are resolved
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+        val childOutputDataTypes = child.output.map(_.dataType)
+        // Only check attributes, not partitionKeys since they are always strings.
+        // TODO: Fully support inserting into partitioned tables.
+        val tableOutputDataTypes = table.attributes.map(_.dataType)
+
+        if (childOutputDataTypes == tableOutputDataTypes) {
+          p
+        } else {
+          // Only do the casting when child output data types differ from table output data types.
+          val castedChildOutput = child.output.zip(table.output).map {
+            case (input, table) if input.dataType != table.dataType =>
+              Alias(Cast(input, table.dataType), input.name)()
+            case (input, _) => input
+          }
+
+          p.copy(child = logical.Project(castedChildOutput, child))
+        }
+    }
+  }
+
+  /**
+   * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+   * For now, if this functionality is desired, mix in the in-memory [[OverrideCatalog]].
+   */
+  override def registerTable(
+    databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+} + +object HiveMetastoreTypes extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + "string" ^^^ StringType | + "float" ^^^ FloatType | + "int" ^^^ IntegerType | + "tinyint" ^^^ ShortType | + "double" ^^^ DoubleType | + "bigint" ^^^ LongType | + "binary" ^^^ BinaryType | + "boolean" ^^^ BooleanType | + "decimal" ^^^ DecimalType | + "varchar\\((\\d+)\\)".r ^^^ StringType + + protected lazy val arrayType: Parser[DataType] = + "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + + protected lazy val mapType: Parser[DataType] = + "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { + case t1 ~ _ ~ t2 => MapType(t1, t2) + } + + protected lazy val structField: Parser[StructField] = + "[a-zA-Z0-9]*".r ~ ":" ~ dataType ^^ { + case name ~ _ ~ tpe => StructField(name, tpe, nullable = true) + } + + protected lazy val structType: Parser[DataType] = + "struct" ~> "<" ~> repsep(structField,",") <~ ">" ^^ StructType + + protected lazy val dataType: Parser[DataType] = + arrayType | + mapType | + structType | + primitiveType + + def toDataType(metastoreType: String): DataType = parseAll(dataType, metastoreType) match { + case Success(result, _) => result + case failure: NoSuccess => sys.error(s"Unsupported dataType: $metastoreType") + } + + def toMetastoreType(dt: DataType): String = dt match { + case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case StructType(fields) => + s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" + case MapType(keyType, valueType) => + s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" + case StringType => "string" + case FloatType => "float" + case IntegerType => "int" + case ShortType =>"tinyint" + case DoubleType => "double" + case LongType => "bigint" + case BinaryType => "binary" + case BooleanType => "boolean" + case DecimalType => "decimal" + } +} + +case class MetastoreRelation(databaseName: String, tableName: String, alias: Option[String]) + 
(val table: TTable, val partitions: Seq[TPartition]) + extends BaseRelation { + // TODO: Can we use org.apache.hadoop.hive.ql.metadata.Table as the type of table and + // use org.apache.hadoop.hive.ql.metadata.Partition as the type of elements of partitions. + // Right now, using org.apache.hadoop.hive.ql.metadata.Table and + // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException + // which indicates the SerDe we used is not Serializable. + + def hiveQlTable = new Table(table) + + def hiveQlPartitions = partitions.map { p => + new Partition(hiveQlTable, p) + } + + override def isPartitioned = hiveQlTable.isPartitioned + + val tableDesc = new TableDesc( + Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], + hiveQlTable.getInputFormatClass, + // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because + // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to + // substitute some output formats, e.g. substituting SequenceFileOutputFormat to + // HiveSequenceFileOutputFormat. + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata + ) + + implicit class SchemaAttribute(f: FieldSchema) { + def toAttribute = AttributeReference( + f.getName, + HiveMetastoreTypes.toDataType(f.getType), + // Since data can be dumped in randomly with no validation, everything is nullable. + nullable = true + )(qualifiers = tableName +: alias.toSeq) + } + + // Must be a stable value since new attributes are born here. 
+ val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + + /** Non-partitionKey attributes */ + val attributes = table.getSd.getCols.map(_.toAttribute) + + val output = attributes ++ partitionKeys +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala new file mode 100644 index 000000000..4f33a293c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -0,0 +1,966 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.hive.ql.lib.Node +import org.apache.hadoop.hive.ql.parse._ +import org.apache.hadoop.hive.ql.plan.PlanUtils + +import catalyst.analysis._ +import catalyst.expressions._ +import catalyst.plans._ +import catalyst.plans.logical +import catalyst.plans.logical._ +import catalyst.types._ + +/** + * Used when we need to start parsing the AST before deciding that we are going to pass the command + * back for Hive to execute natively. Will be replaced with a native command that contains the + * cmd string. 
+ */ +case object NativePlaceholder extends Command + +case class DfsCommand(cmd: String) extends Command + +case class ShellCommand(cmd: String) extends Command + +case class SourceCommand(filePath: String) extends Command + +case class AddJar(jarPath: String) extends Command + +case class AddFile(filePath: String) extends Command + +/** Provides a mapping from HiveQL statments to catalyst logical plans and expression trees. */ +object HiveQl { + protected val nativeCommands = Seq( + "TOK_DESCFUNCTION", + "TOK_DESCTABLE", + "TOK_DESCDATABASE", + "TOK_SHOW_TABLESTATUS", + "TOK_SHOWDATABASES", + "TOK_SHOWFUNCTIONS", + "TOK_SHOWINDEXES", + "TOK_SHOWINDEXES", + "TOK_SHOWPARTITIONS", + "TOK_SHOWTABLES", + + "TOK_LOCKTABLE", + "TOK_SHOWLOCKS", + "TOK_UNLOCKTABLE", + + "TOK_CREATEROLE", + "TOK_DROPROLE", + "TOK_GRANT", + "TOK_GRANT_ROLE", + "TOK_REVOKE", + "TOK_SHOW_GRANT", + "TOK_SHOW_ROLE_GRANT", + + "TOK_CREATEFUNCTION", + "TOK_DROPFUNCTION", + + "TOK_ANALYZE", + "TOK_ALTERDATABASE_PROPERTIES", + "TOK_ALTERINDEX_PROPERTIES", + "TOK_ALTERINDEX_REBUILD", + "TOK_ALTERTABLE_ADDCOLS", + "TOK_ALTERTABLE_ADDPARTS", + "TOK_ALTERTABLE_ALTERPARTS", + "TOK_ALTERTABLE_ARCHIVE", + "TOK_ALTERTABLE_CLUSTER_SORT", + "TOK_ALTERTABLE_DROPPARTS", + "TOK_ALTERTABLE_PARTITION", + "TOK_ALTERTABLE_PROPERTIES", + "TOK_ALTERTABLE_RENAME", + "TOK_ALTERTABLE_RENAMECOL", + "TOK_ALTERTABLE_REPLACECOLS", + "TOK_ALTERTABLE_SKEWED", + "TOK_ALTERTABLE_TOUCH", + "TOK_ALTERTABLE_UNARCHIVE", + "TOK_ANALYZE", + "TOK_CREATEDATABASE", + "TOK_CREATEFUNCTION", + "TOK_CREATEINDEX", + "TOK_DROPDATABASE", + "TOK_DROPINDEX", + "TOK_DROPTABLE", + "TOK_MSCK", + + // TODO(marmbrus): Figure out how view are expanded by hive, as we might need to handle this. 
+    "TOK_ALTERVIEW_ADDPARTS",
+    "TOK_ALTERVIEW_AS",
+    "TOK_ALTERVIEW_DROPPARTS",
+    "TOK_ALTERVIEW_PROPERTIES",
+    "TOK_ALTERVIEW_RENAME",
+    "TOK_CREATEVIEW",
+    "TOK_DROPVIEW",
+
+    "TOK_EXPORT",
+    "TOK_IMPORT",
+    "TOK_LOAD",
+
+    "TOK_SWITCHDATABASE"
+  )
+
+  /**
+   * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
+   * similar to [[catalyst.trees.TreeNode]].
+   *
+   * Note that this should be considered very experimental and is not intended as a replacement
+   * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to
+   * have clean copy semantics. Therefore, users of this class should take care when
+   * copying/modifying trees that might be used elsewhere.
+   */
+  implicit class TransformableNode(n: ASTNode) {
+    /**
+     * Returns a copy of this node where `rule` has been recursively applied to it and all of its
+     * children. When `rule` does not apply to a given node it is left unchanged.
+     * @param rule the function used to transform this node's children
+     */
+    def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = {
+      try {
+        val afterRule = rule.applyOrElse(n, identity[ASTNode])
+        afterRule.withChildren(
+          nilIfEmpty(afterRule.getChildren)
+            .asInstanceOf[Seq[ASTNode]]
+            .map(ast => Option(ast).map(_.transform(rule)).orNull))
+      } catch {
+        case e: Exception =>
+          println(dumpTree(n))
+          throw e
+      }
+    }
+
+    /**
+     * Returns a scala.Seq equivalent to [s] or Nil if [s] is null.
+     */
+    private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] =
+      Option(s).map(_.toSeq).getOrElse(Nil)
+
+    /**
+     * Returns this ASTNode with the text changed to `newText`.
+     */
+    def withText(newText: String): ASTNode = {
+      n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText)
+      n
+    }
+
+    /**
+     * Returns this ASTNode with the children changed to `newChildren`.
+ */ + def withChildren(newChildren: Seq[ASTNode]): ASTNode = { + (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) + n.addChildren(newChildren) + n + } + + /** + * Throws an error if this is not equal to other. + * + * Right now this function only checks the name, type, text and children of the node + * for equality. + */ + def checkEquals(other: ASTNode) { + def check(field: String, f: ASTNode => Any) = if (f(n) != f(other)) { + sys.error(s"$field does not match for trees. " + + s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") + } + check("name", _.getName) + check("type", _.getType) + check("text", _.getText) + check("numChildren", n => nilIfEmpty(n.getChildren).size) + + val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] + val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] + leftChildren zip rightChildren foreach { + case (l,r) => l checkEquals r + } + } + } + + class ParseException(sql: String, cause: Throwable) + extends Exception(s"Failed to parse: $sql", cause) + + /** + * Returns the AST for the given SQL string. + */ + def getAst(sql: String): ASTNode = ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql)) + + /** Returns a LogicalPlan for a given HiveQL string. 
*/ + def parseSql(sql: String): LogicalPlan = { + try { + if (sql.toLowerCase.startsWith("set")) { + NativeCommand(sql) + } else if (sql.toLowerCase.startsWith("add jar")) { + AddJar(sql.drop(8)) + } else if (sql.toLowerCase.startsWith("add file")) { + AddFile(sql.drop(9)) + } else if (sql.startsWith("dfs")) { + DfsCommand(sql) + } else if (sql.startsWith("source")) { + SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath }) + } else if (sql.startsWith("!")) { + ShellCommand(sql.drop(1)) + } else { + val tree = getAst(sql) + + if (nativeCommands contains tree.getText) { + NativeCommand(sql) + } else { + nodeToPlan(tree) match { + case NativePlaceholder => NativeCommand(sql) + case other => other + } + } + } + } catch { + case e: Exception => throw new ParseException(sql, e) + } + } + + def parseDdl(ddl: String): Seq[Attribute] = { + val tree = + try { + ParseUtils.findRootNonNullToken( + (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) + } catch { + case pe: org.apache.hadoop.hive.ql.parse.ParseException => + throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) + } + assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") + val tableOps = tree.getChildren + val colList = + tableOps + .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")).getChildren + + colList.map(nodeToAttribute) + } + + /** Extractor for matching Hive's AST Tokens. */ + object Token { + /** @return matches of the form (tokenName, children). 
*/ + def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { + case t: ASTNode => + Some((t.getText, + Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + case _ => None + } + } + + protected def getClauses(clauseNames: Seq[String], nodeList: Seq[ASTNode]): Seq[Option[Node]] = { + var remainingNodes = nodeList + val clauses = clauseNames.map { clauseName => + val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) + remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) + matches.headOption + } + + assert(remainingNodes.isEmpty, + s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + clauses + } + + def getClause(clauseName: String, nodeList: Seq[Node]) = + getClauseOption(clauseName, nodeList).getOrElse(sys.error( + s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) + + def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { + nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { + case Seq(oneMatch) => Some(oneMatch) + case Seq() => None + case _ => sys.error(s"Found multiple instances of clause $clauseName") + } + } + + protected def nodeToAttribute(node: Node): Attribute = node match { + case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => + AttributeReference(colName, nodeToDataType(dataType), true)() + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + protected def nodeToDataType(node: Node): DataType = node match { + case Token("TOK_BIGINT", Nil) => IntegerType + case Token("TOK_INT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => IntegerType + case Token("TOK_SMALLINT", Nil) => IntegerType + case Token("TOK_BOOLEAN", Nil) => BooleanType + case Token("TOK_STRING", Nil) => StringType + case Token("TOK_FLOAT", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => FloatType + case 
Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) + case Token("TOK_STRUCT", + Token("TOK_TABCOLLIST", fields) :: Nil) => + StructType(fields.map(nodeToStructField)) + case Token("TOK_MAP", + keyType :: + valueType :: Nil) => + MapType(nodeToDataType(keyType), nodeToDataType(valueType)) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") + } + + protected def nodeToStructField(node: Node): StructField = node match { + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case Token("TOK_TABCOL", + Token(fieldName, Nil) :: + dataType :: + _ /* comment */:: Nil) => + StructField(fieldName, nodeToDataType(dataType), nullable = true) + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") + } + + protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { + exprs.zipWithIndex.map { + case (ne: NamedExpression, _) => ne + case (e, i) => Alias(e, s"c_$i")() + } + } + + protected def nodeToPlan(node: Node): LogicalPlan = node match { + // Just fake explain for any of the native commands. + case Token("TOK_EXPLAIN", explainArgs) if nativeCommands contains explainArgs.head.getText => + NoRelation + case Token("TOK_EXPLAIN", explainArgs) => + // Ignore FORMATTED if present. + val Some(query) :: _ :: _ :: Nil = + getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) + // TODO: support EXTENDED? + ExplainCommand(nodeToPlan(query)) + + case Token("TOK_CREATETABLE", children) + if children.collect { case t@Token("TOK_QUERY", _) => t }.nonEmpty => + // TODO: Parse other clauses. 
+ // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val ( + Some(tableNameParts) :: + _ /* likeTable */ :: + Some(query) +: + notImplemented) = + getClauses( + Seq( + "TOK_TABNAME", + "TOK_LIKETABLE", + "TOK_QUERY", + "TOK_IFNOTEXISTS", + "TOK_TABLECOMMENT", + "TOK_TABCOLLIST", + "TOK_TABLEPARTCOLS", // Partitioned by + "TOK_TABLEBUCKETS", // Clustered by + "TOK_TABLESKEWED", // Skewed by + "TOK_TABLEROWFORMAT", + "TOK_TABLESERIALIZER", + "TOK_FILEFORMAT_GENERIC", // For file formats not natively supported by Hive. + "TOK_TBLSEQUENCEFILE", // Stored as SequenceFile + "TOK_TBLTEXTFILE", // Stored as TextFile + "TOK_TBLRCFILE", // Stored as RCFile + "TOK_TBLORCFILE", // Stored as ORC File + "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat + "TOK_STORAGEHANDLER", // Storage handler + "TOK_TABLELOCATION", + "TOK_TABLEPROPERTIES"), + children) + if (notImplemented.exists(token => !token.isEmpty)) { + throw new NotImplementedError( + s"Unhandled clauses: ${notImplemented.flatten.map(dumpTree(_)).mkString("\n")}") + } + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) + + // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. + case Token("TOK_CREATETABLE", _) => NativePlaceholder + + case Token("TOK_QUERY", + Token("TOK_FROM", fromClause :: Nil) :: + insertClauses) => + + // Return one query for each insert clause. 
+ val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => + val ( + intoClause :: + destClause :: + selectClause :: + selectDistinctClause :: + whereClause :: + groupByClause :: + orderByClause :: + sortByClause :: + clusterByClause :: + distributeByClause :: + limitClause :: + lateralViewClause :: Nil) = { + getClauses( + Seq( + "TOK_INSERT_INTO", + "TOK_DESTINATION", + "TOK_SELECT", + "TOK_SELECTDI", + "TOK_WHERE", + "TOK_GROUPBY", + "TOK_ORDERBY", + "TOK_SORTBY", + "TOK_CLUSTERBY", + "TOK_DISTRIBUTEBY", + "TOK_LIMIT", + "TOK_LATERAL_VIEW"), + singleInsert) + } + + val relations = nodeToRelation(fromClause) + val withWhere = whereClause.map { whereNode => + val Seq(whereExpr) = whereNode.getChildren.toSeq + Filter(nodeToExpr(whereExpr), relations) + }.getOrElse(relations) + + val select = + (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) + + // Script transformations are expressed as a select clause with a single expression of type + // TOK_TRANSFORM + val transformation = select.getChildren.head match { + case Token("TOK_SELEXPR", + Token("TOK_TRANSFORM", + Token("TOK_EXPLIST", inputExprs) :: + Token("TOK_SERDE", Nil) :: + Token("TOK_RECORDWRITER", writerClause) :: + // TODO: Need to support other types of (in/out)put + Token(script, Nil) :: + Token("TOK_SERDE", serdeClause) :: + Token("TOK_RECORDREADER", readerClause) :: + outputClause :: Nil) :: Nil) => + + val output = outputClause match { + case Token("TOK_ALIASLIST", aliases) => + aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() } + case Token("TOK_TABCOLLIST", attributes) => + attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() } + } + val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) + + Some( + logical.ScriptTransformation( + inputExprs.map(nodeToExpr), + unescapedScript, + output, + withWhere)) + case _ => None + } + 
+ val withLateralView = lateralViewClause.map { lv => + val Token("TOK_SELECT", + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + + val alias = + getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + + Generate( + nodesToGenerator(clauses), + join = true, + outer = false, + Some(alias.toLowerCase), + withWhere) + }.getOrElse(withWhere) + + + // The projection of the query can either be a normal projection, an aggregation + // (if there is a group by) or a script transformation. + val withProject = transformation.getOrElse { + // Not a transformation so must be either project or aggregation. + val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr)) + + groupByClause match { + case Some(groupBy) => + Aggregate(groupBy.getChildren.map(nodeToExpr), selectExpressions, withLateralView) + case None => + Project(selectExpressions, withLateralView) + } + } + + val withDistinct = + if (selectDistinctClause.isDefined) Distinct(withProject) else withProject + + val withSort = + (orderByClause, sortByClause, distributeByClause, clusterByClause) match { + case (Some(totalOrdering), None, None, None) => + Sort(totalOrdering.getChildren.map(nodeToSortOrder), withDistinct) + case (None, Some(perPartitionOrdering), None, None) => + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withDistinct) + case (None, None, Some(partitionExprs), None) => + Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => + SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), + Repartition(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + case (None, None, None, Some(clusterExprs)) => + SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), + Repartition(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + case (None, None, None, None) => withDistinct + case _ 
=> sys.error("Unsupported set of ordering / distribution clauses.")
+        }
+
+      val withLimit =
+        limitClause.map(l => nodeToExpr(l.getChildren.head))
+          .map(StopAfter(_, withSort))
+          .getOrElse(withSort)
+
+      // TOK_INSERT_INTO means to add files to the table.
+      // TOK_DESTINATION means to overwrite the table.
+      val resultDestination =
+        (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+      val overwrite = if (intoClause.isEmpty) true else false
+      nodeToDest(
+        resultDestination,
+        withLimit,
+        overwrite)
+    }
+
+    // If there are multiple INSERTS just UNION them together into one query.
+    queries.reduceLeft(Union)
+
+  case Token("TOK_UNION", left :: right :: Nil) => Union(nodeToPlan(left), nodeToPlan(right))
+
+  case a: ASTNode =>
+    throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ")
+}
+
+val allJoinTokens = "(TOK_.*JOIN)".r
+val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+def nodeToRelation(node: Node): LogicalPlan = node match {
+  case Token("TOK_SUBQUERY",
+    query :: Token(alias, Nil) :: Nil) =>
+    Subquery(alias, nodeToPlan(query))
+
+  case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+    val Token("TOK_SELECT",
+      Token("TOK_SELEXPR", clauses) :: Nil) = selectClause
+
+    val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
+
+    Generate(
+      nodesToGenerator(clauses),
+      join = true,
+      outer = isOuter.nonEmpty,
+      Some(alias.toLowerCase),
+      nodeToRelation(relationClause))
+
+  /* All relations, possibly with aliases or sampling clauses. */
+  case Token("TOK_TABREF", clauses) =>
+    // If the last clause is not a token then it's the alias of the table.
+ val (nonAliasClauses, aliasClause) = + if (clauses.last.getText.startsWith("TOK")) { + (clauses, None) + } else { + (clauses.dropRight(1), Some(clauses.last)) + } + + val (Some(tableNameParts) :: + splitSampleClause :: + bucketSampleClause :: Nil) = { + getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), + nonAliasClauses) + } + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } + val relation = UnresolvedRelation(db, tableName, alias) + + // Apply sampling if requested. + (bucketSampleClause orElse splitSampleClause).map { + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_ROWCOUNT", Nil) :: + Token(count, Nil) :: Nil) => + StopAfter(Literal(count.toInt), relation) + case Token("TOK_TABLESPLITSAMPLE", + Token("TOK_PERCENT", Nil) :: + Token(fraction, Nil) :: Nil) => + Sample(fraction.toDouble, withReplacement = false, (math.random * 1000).toInt, relation) + case Token("TOK_TABLEBUCKETSAMPLE", + Token(numerator, Nil) :: + Token(denominator, Nil) :: Nil) => + val fraction = numerator.toDouble / denominator.toDouble + Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation) + case a: ASTNode => + throw new NotImplementedError( + s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : + |${dumpTree(a).toString}" + + """.stripMargin) + }.getOrElse(relation) + + case Token("TOK_UNIQUEJOIN", joinArgs) => + val tableOrdinals = + joinArgs.zipWithIndex.filter { + case (arg, i) => arg.getText == "TOK_TABREF" + }.map(_._2) + + val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") + val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) + val joinExpressions = tableOrdinals.map(i => joinArgs(i + 
1).getChildren.map(nodeToExpr)) + + val joinConditions = joinExpressions.sliding(2).map { + case Seq(c1, c2) => + val predicates = (c1, c2).zipped.map { case (e1, e2) => Equals(e1, e2): Expression } + predicates.reduceLeft(And) + }.toBuffer + + val joinType = isPreserved.sliding(2).map { + case Seq(true, true) => FullOuter + case Seq(true, false) => LeftOuter + case Seq(false, true) => RightOuter + case Seq(false, false) => Inner + }.toBuffer + + val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + + // Must be transform down. + val joinedResult = joinedTables transform { + case j: Join => + j.copy( + condition = Some(joinConditions.remove(joinConditions.length - 1)), + joinType = joinType.remove(joinType.length - 1)) + } + + val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + + // Unique join is not really the same as an outer join so we must group together results where + // the joinExpressions are the same, taking the First of each value is only okay because the + // user of a unique join is implicitly promising that there is only one result. + // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. + // instead we should figure out how important supporting this feature is and whether it is + // worth the number of hacks that will be required to implement it. Namely, we need to add + // some sort of mapped star expansion that would expand all child output row to be similarly + // named output expressions where some aggregate expression has been applied (i.e. First). + ??? 
/// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + + case Token(allJoinTokens(joinToken), + relation1 :: + relation2 :: other) => + assert(other.size <= 1, s"Unhandled join child ${other}") + val joinType = joinToken match { + case "TOK_JOIN" => Inner + case "TOK_RIGHTOUTERJOIN" => RightOuter + case "TOK_LEFTOUTERJOIN" => LeftOuter + case "TOK_FULLOUTERJOIN" => FullOuter + } + assert(other.size <= 1, "Unhandled join clauses.") + Join(nodeToRelation(relation1), + nodeToRelation(relation2), + joinType, + other.headOption.map(nodeToExpr)) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + def nodeToSortOrder(node: Node): SortOrder = node match { + case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Ascending) + case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => + SortOrder(nodeToExpr(sortExpr), Descending) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r + protected def nodeToDest( + node: Node, + query: LogicalPlan, + overwrite: Boolean): LogicalPlan = node match { + case Token(destinationToken(), + Token("TOK_DIR", + Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => + query + + case Token(destinationToken(), + Token("TOK_TAB", + tableArgs) :: Nil) => + val Some(tableNameParts) :: partitionClause :: Nil = + getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) + + val (db, tableName) = + tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + case Seq(tableOnly) => (None, tableOnly) + case Seq(databaseName, table) => (Some(databaseName), table) + } + + val partitionKeys = partitionClause.map(_.getChildren.map { + // Parse partitions. We also make keys case insensitive. 
+ case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) + case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => + cleanIdentifier(key.toLowerCase) -> None + }.toMap).getOrElse(Map.empty) + + if (partitionKeys.values.exists(p => p.isEmpty)) { + throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + + s"dynamic partitioning.") + } + + InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { + case Token("TOK_SELEXPR", + e :: Nil) => + Some(nodeToExpr(e)) + + case Token("TOK_SELEXPR", + e :: Token(alias, Nil) :: Nil) => + Some(Alias(nodeToExpr(e), alias)()) + + /* Hints are ignored */ + case Token("TOK_HINTLIST", _) => None + + case a: ASTNode => + throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") + } + + + protected val escapedIdentifier = "`([^`]+)`".r + /** Strips backticks from ident if present */ + protected def cleanIdentifier(ident: String): String = ident match { + case escapedIdentifier(i) => i + case plainIdent => plainIdent + } + + val numericAstTypes = Seq( + HiveParser.Number, + HiveParser.TinyintLiteral, + HiveParser.SmallintLiteral, + HiveParser.BigintLiteral) + + /* Case insensitive matches */ + val COUNT = "(?i)COUNT".r + val AVG = "(?i)AVG".r + val SUM = "(?i)SUM".r + val RAND = "(?i)RAND".r + val AND = "(?i)AND".r + val OR = "(?i)OR".r + val NOT = "(?i)NOT".r + val TRUE = "(?i)TRUE".r + val FALSE = "(?i)FALSE".r + + protected def nodeToExpr(node: Node): Expression = node match { + /* Attribute References */ + case Token("TOK_TABLE_OR_COL", + Token(name, Nil) :: Nil) => + UnresolvedAttribute(cleanIdentifier(name)) + case Token(".", qualifier :: Token(attr, Nil) :: Nil) => + 
nodeToExpr(qualifier) match { + case UnresolvedAttribute(qualifierName) => + UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) + // The precedence for . seems to be wrong, so [] binds tighter and we need to go inside to + // find the underlying attribute references. + case GetItem(UnresolvedAttribute(qualifierName), ordinal) => + GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal) + } + + /* Stars (*) */ + case Token("TOK_ALLCOLREF", Nil) => Star(None) + // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only + // have a single child which is tableName. + case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => + Star(Some(name)) + + /* Aggregate Functions */ + case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) + case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + + /* Casts */ + case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_VARCHAR", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), StringType) + case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), IntegerType) + case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), LongType) + case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), FloatType) + case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DoubleType) + case 
Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ShortType) + case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), ByteType) + case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BinaryType) + case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType) + + /* Arithmetic */ + case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) + case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) + case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) + case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) + case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token("DIV", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) + + /* Comparisons */ + case Token("=", left :: right:: Nil) => Equals(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(Equals(nodeToExpr(left), nodeToExpr(right))) + case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) + case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) + case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("LIKE", left :: right:: Nil) => + UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("RLIKE", left :: right:: Nil) => + 
UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("REGEXP", left :: right:: Nil) => + UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => + IsNotNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => + IsNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) => + In(nodeToExpr(value), list.map(nodeToExpr)) + + /* Boolean Logic */ + case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) + case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) + case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + + /* Complex datatype manipulation */ + case Token("[", child :: ordinal :: Nil) => + GetItem(nodeToExpr(child), nodeToExpr(ordinal)) + + /* Other functions */ + case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand + + /* UDFs - Must be last otherwise will preempt built in functions */ + case Token("TOK_FUNCTION", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr)) + case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => + UnresolvedFunction(name, Star(None) :: Nil) + + /* Literals */ + case Token("TOK_NULL", Nil) => Literal(null, NullType) + case Token(TRUE(), Nil) => Literal(true, BooleanType) + case Token(FALSE(), Nil) => Literal(false, BooleanType) + case Token("TOK_STRINGLITERALSEQUENCE", strings) => + Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) + + // This code is adapted from + // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 + case ast: ASTNode if numericAstTypes contains ast.getType => + var v: Literal = null + try { + if (ast.getText.endsWith("L")) { + // Literal bigint. 
+ v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) + } else if (ast.getText.endsWith("S")) { + // Literal smallint. + v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) + } else if (ast.getText.endsWith("Y")) { + // Literal tinyint. + v = Literal(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) + } else if (ast.getText.endsWith("BD")) { + // Literal decimal + val strVal = ast.getText.substring(0, ast.getText.length() - 2) + BigDecimal(strVal) + } else { + v = Literal(ast.getText.toDouble, DoubleType) + v = Literal(ast.getText.toLong, LongType) + v = Literal(ast.getText.toInt, IntegerType) + } + } catch { + case nfe: NumberFormatException => // Do nothing + } + + if (v == null) { + sys.error(s"Failed to parse number ${ast.getText}") + } else { + v + } + + case ast: ASTNode if ast.getType == HiveParser.StringLiteral => + Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) + + case a: ASTNode => + throw new NotImplementedError( + s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : + |${dumpTree(a).toString}" + + """.stripMargin) + } + + + val explode = "(?i)explode".r + def nodesToGenerator(nodes: Seq[Node]): Generator = { + val function = nodes.head + + val attributes = nodes.flatMap { + case Token(a, Nil) => a.toLowerCase :: Nil + case _ => Nil + } + + function match { + case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => + Explode(attributes, nodeToExpr(child)) + + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr)) + + case a: ASTNode => + throw new NotImplementedError( + s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: + |${dumpTree(a).toString} + """.stripMargin) + } + } + + def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) + : StringBuilder = { + node match { + case a: ASTNode => 
builder.append((" " * indent) + a.getText + "\n") + case other => sys.error(s"Non ASTNode encountered: $other") + } + + Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + builder + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala new file mode 100644 index 000000000..92d84208a --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import catalyst.expressions._ +import catalyst.planning._ +import catalyst.plans._ +import catalyst.plans.logical.{BaseRelation, LogicalPlan} + +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.parquet.{ParquetRelation, InsertIntoParquetTable, ParquetTableScan} + +trait HiveStrategies { + // Possibly being too clever with types here... or not clever enough. 
+ self: SQLContext#SparkPlanner => + + val hiveContext: HiveContext + + object Scripts extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ScriptTransformation(input, script, output, child) => + ScriptTransformation(input, script, output, planLater(child))(hiveContext) :: Nil + case _ => Nil + } + } + + object DataSinks extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => + InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => + InsertIntoParquetTable(table, planLater(child))(hiveContext.sparkContext) :: Nil + case _ => Nil + } + } + + object HiveTableScans extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // Push attributes into table scan when possible. + case p @ logical.Project(projectList, m: MetastoreRelation) if isSimpleProject(projectList) => + HiveTableScan(projectList.asInstanceOf[Seq[Attribute]], m, None)(hiveContext) :: Nil + case m: MetastoreRelation => + HiveTableScan(m.output, m, None)(hiveContext) :: Nil + case _ => Nil + } + } + + /** + * A strategy used to detect filtering predicates on top of a partitioned relation to help + * partition pruning. + * + * This strategy itself doesn't perform partition pruning, it just collects and combines all the + * partition pruning predicates and pass them down to the underlying [[HiveTableScan]] operator, + * which does the actual pruning work. 
+ */ + object PartitionPrunings extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p @ FilteredOperation(predicates, relation: MetastoreRelation) + if relation.isPartitioned => + + val partitionKeyIds = relation.partitionKeys.map(_.id).toSet + + // Filter out all predicates that only deal with partition keys + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.id).subsetOf(partitionKeyIds) + } + + val scan = HiveTableScan( + relation.output, relation, pruningPredicates.reduceLeftOption(And))(hiveContext) + + otherPredicates + .reduceLeftOption(And) + .map(Filter(_, scan)) + .getOrElse(scan) :: Nil + + case _ => + Nil + } + } + + /** + * A strategy that detects projects and filters over some relation and applies column pruning if + * possible. Partition pruning is applied first if the relation is partitioned. + */ + object ColumnPrunings extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // TODO(andre): the current mix of HiveRelation and ParquetRelation + // here appears artificial; try to refactor to break it into two + case PhysicalOperation(projectList, predicates, relation: BaseRelation) => + val predicateOpt = predicates.reduceOption(And) + val predicateRefs = predicateOpt.map(_.references).getOrElse(Set.empty) + val projectRefs = projectList.flatMap(_.references) + + // To figure out what columns to preserve after column pruning, we need to consider: + // + // 1. Columns referenced by the project list (order preserved) + // 2. Columns referenced by filtering predicates but not by project list + // 3. 
Relation output + // + // Then the final result is ((1 union 2) intersect 3) + val prunedCols = (projectRefs ++ (predicateRefs -- projectRefs)).intersect(relation.output) + + val filteredScans = + if (relation.isPartitioned) { // from here on relation must be a [[MetaStoreRelation]] + // Applies partition pruning first for partitioned table + val filteredRelation = predicateOpt.map(logical.Filter(_, relation)).getOrElse(relation) + PartitionPrunings(filteredRelation).view.map(_.transform { + case scan: HiveTableScan => + scan.copy(attributes = prunedCols)(hiveContext) + }) + } else { + val scan = relation match { + case MetastoreRelation(_, _, _) => { + HiveTableScan( + prunedCols, + relation.asInstanceOf[MetastoreRelation], + None)(hiveContext) + } + case ParquetRelation(_, _) => { + ParquetTableScan( + relation.output, + relation.asInstanceOf[ParquetRelation], + None)(hiveContext.sparkContext) + .pruneColumns(prunedCols) + } + } + predicateOpt.map(execution.Filter(_, scan)).getOrElse(scan) :: Nil + } + + if (isSimpleProject(projectList) && prunedCols == projectRefs) { + filteredScans + } else { + filteredScans.view.map(execution.Project(projectList, _)) + } + + case _ => + Nil + } + } + + /** + * Returns true if `projectList` only performs column pruning and does not evaluate other + * complex expressions. + */ + def isSimpleProject(projectList: Seq[NamedExpression]) = { + projectList.forall(_.isInstanceOf[Attribute]) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala new file mode 100644 index 000000000..f20e9d4de --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ScriptTransformation.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import java.io.{InputStreamReader, BufferedReader} + +import catalyst.expressions._ +import org.apache.spark.sql.execution._ + +import scala.collection.JavaConversions._ + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. + */ +case class ScriptTransformation( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan)(@transient sc: HiveContext) + extends UnaryNode { + + override def otherCopyArgs = sc :: Nil + + def execute() = { + child.execute().mapPartitions { iter => + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd) + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val reader = new BufferedReader(new InputStreamReader(inputStream)) + + // TODO: This should be exposed as an iterator instead of reading in all the data at once. 
+ val outputLines = collection.mutable.ArrayBuffer[Row]() + val readerThread = new Thread("Transform OutputReader") { + override def run() { + var curLine = reader.readLine() + while (curLine != null) { + // TODO: Use SerDe + outputLines += new GenericRow(curLine.split("\t").asInstanceOf[Array[Any]]) + curLine = reader.readLine() + } + } + } + readerThread.start() + val outputProjection = new Projection(input) + iter + .map(outputProjection) + // TODO: Use SerDe + .map(_.mkString("", "\t", "\n").getBytes).foreach(outputStream.write) + outputStream.close() + readerThread.join() + outputLines.toIterator + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala new file mode 100644 index 000000000..71d751cbc --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package hive + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.hadoop.hive.serde2.Deserializer +import org.apache.hadoop.hive.ql.exec.Utilities +import org.apache.hadoop.io.Writable +import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf, InputFormat} + +import org.apache.spark.SerializableWritable +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.{HadoopRDD, UnionRDD, EmptyRDD, RDD} + + +/** + * A trait for subclasses that handle table scans. + */ +private[hive] sealed trait TableReader { + def makeRDDForTable(hiveTable: HiveTable): RDD[_] + + def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] + +} + + +/** + * Helper class for scanning tables stored in Hadoop - e.g., to read Hive tables that reside in the + * data warehouse directory. + */ +private[hive] +class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveContext) + extends TableReader { + + // Choose the minimum number of splits. If mapred.map.tasks is set, then use that unless + // it is smaller than what Spark suggests. + private val _minSplitsPerRDD = math.max( + sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinSplits) + + + // TODO: set aws s3 credentials. 
+ + private val _broadcastedHiveConf = + sc.sparkContext.broadcast(new SerializableWritable(sc.hiveconf)) + + def broadcastedHiveConf = _broadcastedHiveConf + + def hiveConf = _broadcastedHiveConf.value.value + + override def makeRDDForTable(hiveTable: HiveTable): RDD[_] = + makeRDDForTable( + hiveTable, + _tableDesc.getDeserializerClass.asInstanceOf[Class[Deserializer]], + filterOpt = None) + + /** + * Creates a Hadoop RDD to read data from the target table's data directory. Returns a transformed + * RDD that contains deserialized rows. + * + * @param hiveTable Hive metadata for the table being scanned. + * @param deserializerClass Class of the SerDe used to deserialize Writables read from Hadoop. + * @param filterOpt If defined, then the filter is used to reject files contained in the data + * directory being read. If None, then all files are accepted. + */ + def makeRDDForTable( + hiveTable: HiveTable, + deserializerClass: Class[_ <: Deserializer], + filterOpt: Option[PathFilter]): RDD[_] = + { + assert(!hiveTable.isPartitioned, """makeRDDForTable() cannot be called on a partitioned table, + since input formats may differ across partitions. Use makeRDDForTablePartitions() instead.""") + + // Create local references to member variables, so that the entire `this` object won't be + // serialized in the closure below. 
+ val tableDesc = _tableDesc + val broadcastedHiveConf = _broadcastedHiveConf + + val tablePath = hiveTable.getPath + val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt) + + //logDebug("Table input: %s".format(tablePath)) + val ifc = hiveTable.getInputFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => + val hconf = broadcastedHiveConf.value.value + val deserializer = deserializerClass.newInstance() + deserializer.initialize(hconf, tableDesc.getProperties) + + // Deserialize each Writable to get the row value. + iter.map { + case v: Writable => deserializer.deserialize(v) + case value => + sys.error(s"Unable to deserialize non-Writable: $value of ${value.getClass.getName}") + } + } + deserializedHadoopRDD + } + + override def makeRDDForPartitionedTable(partitions: Seq[HivePartition]): RDD[_] = { + val partitionToDeserializer = partitions.map(part => + (part, part.getDeserializer.getClass.asInstanceOf[Class[Deserializer]])).toMap + makeRDDForPartitionedTable(partitionToDeserializer, filterOpt = None) + } + + /** + * Create a HadoopRDD for every partition key specified in the query. Note that for on-disk Hive + * tables, a data directory is created for each partition corresponding to keys specified using + * 'PARTITION BY'. + * + * @param partitionToDeserializer Mapping from a Hive Partition metadata object to the SerDe + * class to use to deserialize input Writables from the corresponding partition. + * @param filterOpt If defined, then the filter is used to reject files contained in the data + * subdirectory of each partition being read. If None, then all files are accepted. 
+ */ + def makeRDDForPartitionedTable( + partitionToDeserializer: Map[HivePartition, Class[_ <: Deserializer]], + filterOpt: Option[PathFilter]): RDD[_] = + { + val hivePartitionRDDs = partitionToDeserializer.map { case (partition, partDeserializer) => + val partDesc = Utilities.getPartitionDesc(partition) + val partPath = partition.getPartitionPath + val inputPathStr = applyFilterIfNeeded(partPath, filterOpt) + val ifc = partDesc.getInputFileFormatClass + .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] + // Get partition field info + val partSpec = partDesc.getPartSpec + val partProps = partDesc.getProperties + + val partColsDelimited: String = partProps.getProperty(META_TABLE_PARTITION_COLUMNS) + // Partitioning columns are delimited by "/" + val partCols = partColsDelimited.trim().split("/").toSeq + // 'partValues[i]' contains the value for the partitioning column at 'partCols[i]'. + val partValues = if (partSpec == null) { + Array.fill(partCols.size)(new String) + } else { + partCols.map(col => new String(partSpec.get(col))).toArray + } + + // Create local references so that the outer object isn't serialized. 
+ val tableDesc = _tableDesc + val broadcastedHiveConf = _broadcastedHiveConf + val localDeserializer = partDeserializer + + val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) + hivePartitionRDD.mapPartitions { iter => + val hconf = broadcastedHiveConf.value.value + val rowWithPartArr = new Array[Object](2) + // Map each tuple to a row object + iter.map { value => + val deserializer = localDeserializer.newInstance() + deserializer.initialize(hconf, partProps) + val deserializedRow = deserializer.deserialize(value) + rowWithPartArr.update(0, deserializedRow) + rowWithPartArr.update(1, partValues) + rowWithPartArr.asInstanceOf[Object] + } + } + }.toSeq + // Even if we don't use any partitions, we still need an empty RDD + if (hivePartitionRDDs.size == 0) { + new EmptyRDD[Object](sc.sparkContext) + } else { + new UnionRDD(hivePartitionRDDs(0).context, hivePartitionRDDs) + } + } + + /** + * If `filterOpt` is defined, then it will be used to filter files from `path`. These files are + * returned in a single, comma-separated string. + */ + private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { + filterOpt match { + case Some(filter) => + val fs = path.getFileSystem(sc.hiveconf) + val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) + filteredFiles.mkString(",") + case None => path.toString + } + } + + /** + * Creates a HadoopRDD based on the broadcasted HiveConf and other job properties that will be + * applied locally on each slave. 
+ */ + private def createHadoopRdd( + tableDesc: TableDesc, + path: String, + inputFormatClass: Class[InputFormat[Writable, Writable]]) + : RDD[Writable] = { + val initializeJobConfFunc = HadoopTableReader.initializeLocalJobConfFunc(path, tableDesc) _ + + val rdd = new HadoopRDD( + sc.sparkContext, + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + Some(initializeJobConfFunc), + inputFormatClass, + classOf[Writable], + classOf[Writable], + _minSplitsPerRDD) + + // Only take the value (skip the key) because Hive works only with values. + rdd.map(_._2) + } + +} + +private[hive] object HadoopTableReader { + + /** + * Curried. After given an argument for 'path', the resulting JobConf => Unit closure is used to + * instantiate a HadoopRDD. + */ + def initializeLocalJobConfFunc(path: String, tableDesc: TableDesc)(jobConf: JobConf) { + FileInputFormat.setInputPaths(jobConf, path) + if (tableDesc != null) { + Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + jobConf.set("io.file.buffer.size", bufferSize) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala new file mode 100644 index 000000000..17ae4ef63 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import java.io.File +import java.util.{Set => JavaSet} + +import scala.collection.mutable +import scala.collection.JavaConversions._ +import scala.language.implicitConversions + +import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.io.avro.{AvroContainerOutputFormat, AvroContainerInputFormat} +import org.apache.hadoop.hive.ql.metadata.Table +import org.apache.hadoop.hive.serde2.avro.AvroSerDe +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.RegexSerDe + +import org.apache.spark.{SparkContext, SparkConf} + +import catalyst.analysis._ +import catalyst.plans.logical.{LogicalPlan, NativeCommand} +import catalyst.util._ + +object TestHive + extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + +/** + * A locally running test instance of Spark's Hive execution engine. + * + * Data from [[testTables]] will be automatically loaded whenever a query is run over those tables. + * Calling [[reset]] will delete all tables and other state in the database, leaving the database + * in a "clean" state. + * + * TestHive is singleton object version of this class because instantiating multiple copies of the + * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of + * testcases that rely on TestHive must be serialized. 
+ */ +class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { + self => + + // By clearing the port we force Spark to pick a new one. This allows us to rerun tests + // without restarting the JVM. + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + + override lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath + override lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath + + /** The location of the compiled hive distribution */ + lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ + lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") + + // Override so we can intercept relative paths and rewrite them to point at hive. + override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) + + override def executePlan(plan: LogicalPlan): this.QueryExecution = + new this.QueryExecution { val logical = plan } + + /** + * Returns the value of specified environmental variable as a [[java.io.File]] after checking + * to ensure it exists + */ + private def envVarToFile(envVar: String): Option[File] = { + Option(System.getenv(envVar)).map(new File(_)) + } + + /** + * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the + * hive test cases assume the system is set up. 
+ */ + private def rewritePaths(cmd: String): String = + if (cmd.toUpperCase contains "LOAD DATA") { + val testDataLocation = + hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) + cmd.replaceAll("\\.\\.", testDataLocation) + } else { + cmd + } + + val hiveFilesTemp = File.createTempFile("catalystHiveFiles", "") + hiveFilesTemp.delete() + hiveFilesTemp.mkdir() + + val inRepoTests = new File("src/test/resources/") + def getHiveFile(path: String): File = { + val stripped = path.replaceAll("""\.\.\/""", "") + hiveDevHome + .map(new File(_, stripped)) + .filter(_.exists) + .getOrElse(new File(inRepoTests, stripped)) + } + + val describedTable = "DESCRIBE (\\w+)".r + + class SqlQueryExecution(sql: String) extends this.QueryExecution { + lazy val logical = HiveQl.parseSql(sql) + def hiveExec() = runSqlHive(sql) + override def toString = sql + "\n" + super.toString + } + + /** + * Override QueryExecution with special debug workflow. + */ + abstract class QueryExecution extends super.QueryExecution { + override lazy val analyzed = { + val describedTables = logical match { + case NativeCommand(describedTable(tbl)) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(databaseName, name, _) => name } + val referencedTestTables = referencedTables.filter(testTables.contains) + logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(loadTestTable) + // Proceed with analysis. + analyzer(logical) + } + } + + case class TestTable(name: String, commands: (()=>Unit)*) + + implicit class SqlCmd(sql: String) { + def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit + } + + /** + * A list of test tables and the DDL required to initialize them. A test table is loaded on + * demand when a query is run against it. 
+ */ + lazy val testTables = new mutable.HashMap[String, TestTable]() + def registerTestTable(testTable: TestTable) = testTables += (testTable.name -> testTable) + + // The test tables that are defined in the Hive QTestUtil. + // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java + val hiveQTestUtilTables = Seq( + TestTable("src", + "CREATE TABLE src (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd), + TestTable("src1", + "CREATE TABLE src1 (key INT, value STRING)".cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), + TestTable("dest1", + "CREATE TABLE IF NOT EXISTS dest1 (key INT, value STRING)".cmd), + TestTable("dest2", + "CREATE TABLE IF NOT EXISTS dest2 (key INT, value STRING)".cmd), + TestTable("dest3", + "CREATE TABLE IF NOT EXISTS dest3 (key INT, value STRING)".cmd), + TestTable("srcpart", () => { + runSqlHive( + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("srcpart1", () => { + runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { + runSqlHive( + s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin) + } + }), + TestTable("src_thrift", () => { + import org.apache.thrift.protocol.TBinaryProtocol + import org.apache.hadoop.hive.serde2.thrift.test.Complex + import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer + import org.apache.hadoop.mapred.SequenceFileInputFormat + import org.apache.hadoop.mapred.SequenceFileOutputFormat + + val 
srcThrift = new Table("default", "src_thrift") + srcThrift.setFields(Nil) + srcThrift.setInputFormatClass(classOf[SequenceFileInputFormat[_,_]].getName) + // In Hive, SequenceFileOutputFormat will be substituted by HiveSequenceFileOutputFormat. + srcThrift.setOutputFormatClass(classOf[SequenceFileOutputFormat[_,_]].getName) + srcThrift.setSerializationLib(classOf[ThriftDeserializer].getName) + srcThrift.setSerdeParam("serialization.class", classOf[Complex].getName) + srcThrift.setSerdeParam("serialization.format", classOf[TBinaryProtocol].getName) + catalog.client.createTable(srcThrift) + + + runSqlHive( + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") + }), + TestTable("serdeins", + s"""CREATE TABLE serdeins (key INT, value STRING) + |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ('field.delim'='\\t') + """.stripMargin.cmd, + "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd), + TestTable("sales", + s"""CREATE TABLE IF NOT EXISTS sales (key STRING, value INT) + |ROW FORMAT SERDE '${classOf[RegexSerDe].getCanonicalName}' + |WITH SERDEPROPERTIES ("input.regex" = "([^ ]*)\t([^ ]*)") + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/sales.txt")}' INTO TABLE sales".cmd), + TestTable("episodes", + s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": 
"main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd + ) + ) + + hiveQTestUtilTables.foreach(registerTestTable) + + private val loadedTables = new collection.mutable.HashSet[String] + + def loadTestTable(name: String) { + if (!(loadedTables contains name)) { + // Marks the table as loaded first to prevent infite mutually recursive table loading. + loadedTables += name + logger.info(s"Loading test table $name") + val createCmds = + testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) + createCmds.foreach(_()) + } + } + + /** + * Records the UDFs present when the server starts, so we can delete ones that are created by + * tests. + */ + protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + + /** + * Resets the test instance by deleting any tables that have been created. + * TODO: also clear out UDFs, views, etc. + */ + def reset() { + try { + // HACK: Hive is too noisy by default. + org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => + logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + } + + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. 
+ runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + + loadedTables.clear() + catalog.client.getAllTables("default").foreach { t => + logger.debug(s"Deleting table $t") + val table = catalog.client.getTable("default", t) + + catalog.client.getIndexes("default", t, 255).foreach { index => + catalog.client.dropIndex("default", t, index.getIndexName, true) + } + + if (!table.isIndexTable) { + catalog.client.dropTable("default", t) + } + } + + catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => + logger.debug(s"Dropping Database: $db") + catalog.client.dropDatabase(db, true, false, true) + } + + FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.unregisterTemporaryUDF(udfName) + } + + configure() + + runSqlHive("USE default") + + // Just loading src makes a lot of tests pass. This is because some tests do something like + // drop an index on src at the beginning. Since we just pass DDL to hive this bypasses our + // Analyzer and thus the test table auto-loading mechanism. + // Remove after we handle more DDL operations natively. + loadTestTable("src") + loadTestTable("srcpart") + } catch { + case e: Exception => + logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + // At this point there is really no reason to continue, but the test framework traps exits. + // So instead we just pause forever so that at least the developer can see where things + // started to go wrong. + Thread.sleep(100000) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala new file mode 100644 index 000000000..d20fd87f3 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive + +import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.Context +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Hive} +import org.apache.hadoop.hive.ql.plan.{TableDesc, FileSinkDesc} +import org.apache.hadoop.hive.serde2.Serializer +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import catalyst.expressions._ +import catalyst.types.{BooleanType, DataType} +import org.apache.spark.{TaskContext, SparkException} +import catalyst.expressions.Cast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution._ + +import scala.Some +import scala.collection.immutable.ListMap + +/* Implicits */ +import scala.collection.JavaConversions._ + +/** + * The Hive table scan operator. 
Column and partition pruning are both handled. + * + * @constructor + * @param attributes Attributes to be fetched from the Hive table. + * @param relation The Hive table to be scanned. + * @param partitionPruningPred An optional partition pruning predicate for partitioned table. + */ +case class HiveTableScan( + attributes: Seq[Attribute], + relation: MetastoreRelation, + partitionPruningPred: Option[Expression])( + @transient val sc: HiveContext) + extends LeafNode + with HiveInspectors { + + require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, + "Partition pruning predicates only supported for partitioned tables.") + + // Bind all partition key attribute references in the partition pruning predicate for later + // evaluation. + private val boundPruningPred = partitionPruningPred.map { pred => + require( + pred.dataType == BooleanType, + s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.") + + BindReferences.bindReference(pred, relation.partitionKeys) + } + + @transient + val hadoopReader = new HadoopTableReader(relation.tableDesc, sc) + + /** + * The hive object inspector for this table, which can be used to extract values from the + * serialized row representation. + */ + @transient + lazy val objectInspector = + relation.tableDesc.getDeserializer.getObjectInspector.asInstanceOf[StructObjectInspector] + + /** + * Functions that extract the requested attributes from the hive output. Partitioned values are + * cast from string to their declared data types. 
+ */ + @transient + protected lazy val attributeFunctions: Seq[(Any, Array[String]) => Any] = { + attributes.map { a => + val ordinal = relation.partitionKeys.indexOf(a) + if (ordinal >= 0) { + (_: Any, partitionKeys: Array[String]) => { + val value = partitionKeys(ordinal) + val dataType = relation.partitionKeys(ordinal).dataType + castFromString(value, dataType) + } + } else { + val ref = objectInspector.getAllStructFieldRefs + .find(_.getFieldName == a.name) + .getOrElse(sys.error(s"Can't find attribute $a")) + (row: Any, _: Array[String]) => { + val data = objectInspector.getStructFieldData(row, ref) + unwrapData(data, ref.getFieldObjectInspector) + } + } + } + } + + private def castFromString(value: String, dataType: DataType) = { + Cast(Literal(value), dataType).apply(null) + } + + @transient + def inputRdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions)) + } + + /** + * Prunes partitions not involved in the query plan. + * + * @param partitions All partitions of the relation. + * @return Partitions that are involved in the query plan. + */ + private[hive] def prunePartitions(partitions: Seq[HivePartition]) = { + boundPruningPred match { + case None => partitions + case Some(shouldKeep) => partitions.filter { part => + val dataTypes = relation.partitionKeys.map(_.dataType) + val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) yield { + castFromString(value, dataType) + } + + // Only partitioned values are needed here, since the predicate has already been bound to + // partition key attribute references. 
+ val row = new GenericRow(castedValues.toArray) + shouldKeep.apply(row).asInstanceOf[Boolean] + } + } + } + + def execute() = { + inputRdd.map { row => + val values = row match { + case Array(deserializedRow: AnyRef, partitionKeys: Array[String]) => + attributeFunctions.map(_(deserializedRow, partitionKeys)) + case deserializedRow: AnyRef => + attributeFunctions.map(_(deserializedRow, Array.empty)) + } + buildRow(values.map { + case n: String if n.toLowerCase == "null" => null + case varchar: org.apache.hadoop.hive.common.`type`.HiveVarchar => varchar.getValue + case decimal: org.apache.hadoop.hive.common.`type`.HiveDecimal => + BigDecimal(decimal.bigDecimalValue) + case other => other + }) + } + } + + def output = attributes +} + +case class InsertIntoHiveTable( + table: MetastoreRelation, + partition: Map[String, Option[String]], + child: SparkPlan, + overwrite: Boolean) + (@transient sc: HiveContext) + extends UnaryNode { + + val outputClass = newSerializer(table.tableDesc).getSerializedClass + @transient private val hiveContext = new Context(sc.hiveconf) + @transient private val db = Hive.get(sc.hiveconf) + + private def newSerializer(tableDesc: TableDesc): Serializer = { + val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] + serializer.initialize(null, tableDesc.getProperties) + serializer + } + + override def otherCopyArgs = sc :: Nil + + def output = child.output + + /** + * Wraps with Hive types based on object inspector. + * TODO: Consolidate all hive OI/data interface code. 
+ */ + protected def wrap(a: (Any, ObjectInspector)): Any = a match { + case (s: String, oi: JavaHiveVarcharObjectInspector) => new HiveVarchar(s, s.size) + case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) => + new HiveDecimal(bd.underlying()) + case (row: Row, oi: StandardStructObjectInspector) => + val struct = oi.create() + row.zip(oi.getAllStructFieldRefs).foreach { + case (data, field) => + oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector)) + } + struct + case (s: Seq[_], oi: ListObjectInspector) => + val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector)) + seqAsJavaList(wrappedSeq) + case (obj, _) => obj + } + + def saveAsHiveFile( + rdd: RDD[Writable], + valueClass: Class[_], + fileSinkConf: FileSinkDesc, + conf: JobConf, + isCompressed: Boolean) { + if (valueClass == null) { + throw new SparkException("Output value class not set") + } + conf.setOutputValueClass(valueClass) + if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { + throw new SparkException("Output format class not set") + } + // Doesn't work in Scala 2.9 due to what may be a generics bug + // TODO: Should we uncomment this for Scala 2.10? + // conf.setOutputFormat(outputFormatClass) + conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + if (isCompressed) { + // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", + // and "mapred.output.compression.type" have no impact on ORC because it uses table properties + // to store compression information. 
+ conf.set("mapred.output.compress", "true") + fileSinkConf.setCompressed(true) + fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + } + conf.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf, + SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + + logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + + val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) + writer.preSetup() + + def writeToFile(context: TaskContext, iter: Iterator[Writable]) { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + writer.setup(context.stageId, context.partitionId, attemptNumber) + writer.open() + + var count = 0 + while(iter.hasNext) { + val record = iter.next() + count += 1 + writer.write(record) + } + + writer.close() + writer.commit() + } + + sc.sparkContext.runJob(rdd, writeToFile _) + writer.commitJob() + } + + /** + * Inserts all the rows in the table into Hive. Row objects are properly serialized with the + * `org.apache.hadoop.hive.serde2.SerDe` and the + * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. + */ + def execute() = { + val childRdd = child.execute() + assert(childRdd != null) + + // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer + // instances within the closure, since Serializer is not serializable while TableDesc is. 
+ val tableDesc = table.tableDesc + val tableLocation = table.hiveQlTable.getDataLocation + val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) + val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) + val rdd = childRdd.mapPartitions { iter => + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + iter.map { row => + // Casts Strings to HiveVarchars when necessary. + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector) + val mappedRow = row.zip(fieldOIs).map(wrap) + + serializer.serialize(mappedRow.toArray, standardOI) + } + } + + // ORC stores compression information in table properties. While, there are other formats + // (e.g. RCFile) that rely on hadoop configurations to store compression information. + val jobConf = new JobConf(sc.hiveconf) + saveAsHiveFile( + rdd, + outputClass, + fileSinkConf, + jobConf, + sc.hiveconf.getBoolean("hive.exec.compress.output", false)) + + // TODO: Handle dynamic partitioning. + val outputPath = FileOutputFormat.getOutputPath(jobConf) + // Have to construct the format of dbname.tablename. + val qualifiedTableName = s"${table.databaseName}.${table.tableName}" + // TODO: Correctly set holdDDLTime. + // In most of the time, we should have holdDDLTime = false. + // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. + val holdDDLTime = false + if (partition.nonEmpty) { + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" // Should not reach here right now. + } + val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols(), partitionSpec) + db.validatePartitionNameCharacters(partVals) + // inheritTableSpecs is set to true. 
It should be set to false for a IMPORT query + // which is currently considered as a Hive native command. + val inheritTableSpecs = true + // TODO: Correctly set isSkewedStoreAsSubdir. + val isSkewedStoreAsSubdir = false + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } else { + db.loadTable( + outputPath, + qualifiedTableName, + overwrite, + holdDDLTime) + } + + // It would be nice to just return the childRdd unchanged so insert operations could be chained, + // however for now we return an empty list to simplify compatibility checks with hive, which + // does not return anything for insert operations. + // TODO: implement hive compatibility as rules. + sc.sparkContext.makeRDD(Nil, 1) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala new file mode 100644 index 000000000..5e775d6a0 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -0,0 +1,467 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package hive + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.hadoop.hive.serde2.{io => hiveIo} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.{io => hadoopIo} + +import catalyst.analysis +import catalyst.expressions._ +import catalyst.types +import catalyst.types._ + +object HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors { + + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. 
+ val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse( + sys.error(s"Couldn't find function $name")) + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + val function = createFunction[UDF](name) + val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + + lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) + + HiveSimpleUdf( + name, + children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + ) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdf(name, children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdaf(name, children) + + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUdtf(name, Nil, children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } + + def javaClassToDataType(clz: Class[_]): DataType = clz match { + case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType + case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType + case c: Class[_] if c == classOf[hadoopIo.Text] => StringType + case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType + case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType + case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType + case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == 
java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + case c: Class[_] if c == classOf[java.lang.Short] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType + case c: Class[_] if c == classOf[java.lang.Long] => LongType + case c: Class[_] if c == classOf[java.lang.Double] => DoubleType + case c: Class[_] if c == classOf[java.lang.Byte] => ByteType + case c: Class[_] if c == classOf[java.lang.Float] => FloatType + case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType + case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + } +} + +trait HiveFunctionFactory { + def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) + def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass + def createFunction[UDFType](name: String) = + getFunctionClass(name).newInstance.asInstanceOf[UDFType] + + /** Converts hive types to native catalyst types. 
*/ + def unwrap(a: Any): Any = a match { + case null => null + case i: hadoopIo.IntWritable => i.get + case t: hadoopIo.Text => t.toString + case l: hadoopIo.LongWritable => l.get + case d: hadoopIo.DoubleWritable => d.get() + case d: hiveIo.DoubleWritable => d.get + case s: hiveIo.ShortWritable => s.get + case b: hadoopIo.BooleanWritable => b.get() + case b: hiveIo.ByteWritable => b.get + case list: java.util.List[_] => list.map(unwrap) + case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap + case array: Array[_] => array.map(unwrap).toSeq + case p: java.lang.Short => p + case p: java.lang.Long => p + case p: java.lang.Float => p + case p: java.lang.Integer => p + case p: java.lang.Double => p + case p: java.lang.Byte => p + case p: java.lang.Boolean => p + case str: String => str + } +} + +abstract class HiveUdf + extends Expression with Logging with HiveFunctionFactory { + self: Product => + + type UDFType + type EvaluatedType = Any + + val name: String + + def nullable = true + def references = children.flatMap(_.references).toSet + + // FunctionInfo is not serializable so we must look it up here again. 
+ lazy val functionInfo = getFunctionInfo(name) + lazy val function = createFunction[UDFType](name) + + override def toString = s"${nodeName}#${functionInfo.getDisplayName}(${children.mkString(",")})" +} + +case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf { + import HiveFunctionRegistry._ + type UDFType = UDF + + @transient + protected lazy val method = + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + + @transient + lazy val dataType = javaClassToDataType(method.getReturnType) + + protected lazy val wrappers: Array[(Any) => AnyRef] = method.getParameterTypes.map { argClass => + val primitiveClasses = Seq( + Integer.TYPE, classOf[java.lang.Integer], classOf[java.lang.String], java.lang.Double.TYPE, + classOf[java.lang.Double], java.lang.Long.TYPE, classOf[java.lang.Long], + classOf[HiveDecimal], java.lang.Byte.TYPE, classOf[java.lang.Byte] + ) + val matchingConstructor = argClass.getConstructors.find { c => + c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) + } + + val constructor = matchingConstructor.getOrElse( + sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) + + (a: Any) => { + logger.debug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") + // We must make sure that primitives get boxed java style. + if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + } + + // TODO: Finish input output types. + override def apply(input: Row): Any = { + val evaluatedChildren = children.map(_.apply(input)) + // Wrap the function arguments in the expected types. + val args = evaluatedChildren.zip(wrappers).map { + case (arg, wrapper) => wrapper(arg) + } + + // Invoke the udf and unwrap the result. 
+ unwrap(method.invoke(function, args: _*)) + } +} + +case class HiveGenericUdf( + name: String, + children: Seq[Expression]) extends HiveUdf with HiveInspectors { + import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ + type UDFType = GenericUDF + + @transient + protected lazy val argumentInspectors = children.map(_.dataType).map(toInspector) + + @transient + protected lazy val returnInspector = function.initialize(argumentInspectors.toArray) + + val dataType: DataType = inspectorToDataType(returnInspector) + + override def apply(input: Row): Any = { + returnInspector // Make sure initialized. + val args = children.map { v => + new DeferredObject { + override def prepare(i: Int) = {} + override def get(): AnyRef = wrap(v.apply(input)) + } + }.toArray + unwrap(function.evaluate(args)) + } +} + +trait HiveInspectors { + + def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) + case li: ListObjectInspector => + Option(li.getList(data)) + .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq) + .orNull + case mi: MapObjectInspector => + Option(mi.getMap(data)).map( + _.map { + case (k,v) => + (unwrapData(k, mi.getMapKeyObjectInspector), + unwrapData(v, mi.getMapValueObjectInspector)) + }.toMap).orNull + case si: StructObjectInspector => + val allRefs = si.getAllStructFieldRefs + new GenericRow( + allRefs.map(r => + unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + } + + /** Converts native catalyst types to the types expected by Hive */ + def wrap(a: Any): AnyRef = a match { + case s: String => new hadoopIo.Text(s) + case i: Int => i: java.lang.Integer + case b: Boolean => b: java.lang.Boolean + case d: Double => d: java.lang.Double + case l: Long => l: java.lang.Long + case l: Short => l: java.lang.Short + case l: Byte => l: java.lang.Byte + case s: Seq[_] => seqAsJavaList(s.map(wrap)) + case m: Map[_,_] => + mapAsJavaMap(m.map { case (k, 
v) => wrap(k) -> wrap(v) }) + case null => null + } + + def toInspector(dataType: DataType): ObjectInspector = dataType match { + case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case MapType(keyType, valueType) => + ObjectInspectorFactory.getStandardMapObjectInspector( + toInspector(keyType), toInspector(valueType)) + case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector + case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector + case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector + case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector + case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector + case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector + case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector + case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector + case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector + } + + def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { + case s: StructObjectInspector => + StructType(s.getAllStructFieldRefs.map(f => { + types.StructField( + f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) + })) + case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) + case m: MapObjectInspector => + MapType( + inspectorToDataType(m.getMapKeyObjectInspector), + inspectorToDataType(m.getMapValueObjectInspector)) + case _: WritableStringObjectInspector => StringType + case _: JavaStringObjectInspector => StringType + case _: WritableIntObjectInspector => IntegerType + case _: JavaIntObjectInspector => IntegerType + case _: WritableDoubleObjectInspector => DoubleType + case _: JavaDoubleObjectInspector => DoubleType + case _: 
WritableBooleanObjectInspector => BooleanType + case _: JavaBooleanObjectInspector => BooleanType + case _: WritableLongObjectInspector => LongType + case _: JavaLongObjectInspector => LongType + case _: WritableShortObjectInspector => ShortType + case _: JavaShortObjectInspector => ShortType + case _: WritableByteObjectInspector => ByteType + case _: JavaByteObjectInspector => ByteType + } + + implicit class typeInfoConversions(dt: DataType) { + import org.apache.hadoop.hive.serde2.typeinfo._ + import TypeInfoFactory._ + + def toTypeInfo: TypeInfo = dt match { + case BinaryType => binaryTypeInfo + case BooleanType => booleanTypeInfo + case ByteType => byteTypeInfo + case DoubleType => doubleTypeInfo + case FloatType => floatTypeInfo + case IntegerType => intTypeInfo + case LongType => longTypeInfo + case ShortType => shortTypeInfo + case StringType => stringTypeInfo + case DecimalType => decimalTypeInfo + case NullType => voidTypeInfo + } + } +} + +case class HiveGenericUdaf( + name: String, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors + with HiveFunctionFactory { + + type UDFType = AbstractGenericUDAFResolver + + protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + + protected lazy val objectInspector = { + resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + protected lazy val inspectors = children.map(_.dataType).map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + def references: Set[Attribute] = children.map(_.references).flatten.toSet + + override def toString = s"$nodeName#$name(${children.mkString(",")})" + + def newInstance = new HiveUdafFunction(name, children, this) +} + +/** + * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a + * [[catalyst.expressions.Generator Generator]]. 
Note that the semantics of Generators do not allow + * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning + * dependent operations like calls to `close()` before producing output will not operate the same as + * in Hive. However, in practice this should not affect compatibility for most sane UDTFs + * (e.g. explode or GenericUDTFParseUrlTuple). + * + * Operators that require maintaining state in between input rows should instead be implemented as + * user defined aggregations, which have clean semantics even in a partitioned execution. + */ +case class HiveGenericUdtf( + name: String, + aliasNames: Seq[String], + children: Seq[Expression]) + extends Generator with HiveInspectors with HiveFunctionFactory { + + override def references = children.flatMap(_.references).toSet + + @transient + protected lazy val function: GenericUDTF = createFunction(name) + + protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + + protected lazy val outputInspectors = { + val structInspector = function.initialize(inputInspectors.toArray) + structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) + } + + protected lazy val outputDataTypes = outputInspectors.map(inspectorToDataType) + + override protected def makeOutput() = { + // Use column names when given, otherwise c_1, c_2, ... c_n. + if (aliasNames.size == outputDataTypes.size) { + aliasNames.zip(outputDataTypes).map { + case (attrName, attrDataType) => + AttributeReference(attrName, attrDataType, nullable = true)() + } + } else { + outputDataTypes.zipWithIndex.map { + case (attrDataType, i) => + AttributeReference(s"c_$i", attrDataType, nullable = true)() + } + } + } + + override def apply(input: Row): TraversableOnce[Row] = { + outputInspectors // Make sure initialized. 
+ + val inputProjection = new Projection(children) + val collector = new UDTFCollector + function.setCollector(collector) + + val udtInput = inputProjection(input).map(wrap).toArray + function.process(udtInput) + collector.collectRows() + } + + protected class UDTFCollector extends Collector { + var collected = new ArrayBuffer[Row] + + override def collect(input: java.lang.Object) { + // We need to clone the input here because implementations of + // GenericUDTF reuse the same object. Luckily they are always an array, so + // it is easy to clone. + collected += new GenericRow(input.asInstanceOf[Array[_]].map(unwrap)) + } + + def collectRows() = { + val toCollect = collected + collected = new ArrayBuffer[Row] + toCollect + } + } + + override def toString() = s"$nodeName#$name(${children.mkString(",")})" +} + +case class HiveUdafFunction( + functionName: String, + exprs: Seq[Expression], + base: AggregateExpression) + extends AggregateFunction + with HiveInspectors + with HiveFunctionFactory { + + def this() = this(null, null, null) + + private val resolver = createFunction[AbstractGenericUDAFResolver](functionName) + + private val inspectors = exprs.map(_.dataType).map(toInspector).toArray + + private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray) + + private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + + // Cast required to avoid type inference selecting a deprecated Hive API. 
+ private val buffer = + function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + + override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) + + @transient + val inputProjection = new Projection(exprs) + + def update(input: Row): Unit = { + val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray + function.iterate(buffer, inputs) + } +} diff --git a/sql/hive/src/test/resources/log4j.properties b/sql/hive/src/test/resources/log4j.properties new file mode 100644 index 000000000..5e17e3b59 --- /dev/null +++ b/sql/hive/src/test/resources/log4j.properties @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=DEBUG, CA, FA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold = WARN + + +#File Appender +log4j.appender.FA=org.apache.log4j.FileAppender +log4j.appender.FA.append=false +log4j.appender.FA.file=target/unit-tests.log +log4j.appender.FA.layout=org.apache.log4j.PatternLayout +log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n + +# Set the logger level of File Appender to WARN +log4j.appender.FA.Threshold = INFO + +# Some packages are noisy for no good reason. +log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala new file mode 100644 index 000000000..4b45e6986 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BigDataBenchmarkSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +import java.io.File + +/** + * A set of test cases based on the big-data-benchmark. + * https://amplab.cs.berkeley.edu/benchmark/ + */ +class BigDataBenchmarkSuite extends HiveComparisonTest { + import TestHive._ + + val testDataDirectory = new File("target/big-data-benchmark-testdata") + + val testTables = Seq( + TestTable( + "rankings", + s""" + |CREATE EXTERNAL TABLE rankings ( + | pageURL STRING, + | pageRank INT, + | avgDuration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "rankings").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "scratch", + s""" + |CREATE EXTERNAL TABLE scratch ( + | pageURL STRING, + | pageRank INT, + | avgDuration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "scratch").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "uservisits", + s""" + |CREATE EXTERNAL TABLE uservisits ( + | sourceIP STRING, + | destURL STRING, + | visitDate STRING, + | adRevenue DOUBLE, + | userAgent STRING, + | countryCode STRING, + | languageCode STRING, + | searchWord STRING, + | duration INT) + | ROW FORMAT DELIMITED FIELDS TERMINATED BY "," + | STORED AS TEXTFILE LOCATION "${new File(testDataDirectory, "uservisits").getCanonicalPath}" + """.stripMargin.cmd), + TestTable( + "documents", + s""" + |CREATE EXTERNAL TABLE documents (line STRING) + |STORED AS TEXTFILE + |LOCATION "${new File(testDataDirectory, 
"crawl").getCanonicalPath}" + """.stripMargin.cmd)) + + testTables.foreach(registerTestTable) + + if (!testDataDirectory.exists()) { + // TODO: Auto download the files on demand. + ignore("No data files found for BigDataBenchmark tests.") {} + } else { + createQueryTest("query1", + "SELECT pageURL, pageRank FROM rankings WHERE pageRank > 1") + + createQueryTest("query2", + "SELECT SUBSTR(sourceIP, 1, 10), SUM(adRevenue) FROM uservisits GROUP BY SUBSTR(sourceIP, 1, 10)") + + createQueryTest("query3", + """ + |SELECT sourceIP, + | sum(adRevenue) as totalRevenue, + | avg(pageRank) as pageRank + |FROM + | rankings R JOIN + | (SELECT sourceIP, destURL, adRevenue + | FROM uservisits UV + | WHERE UV.visitDate > "1980-01-01" + | AND UV.visitDate < "1980-04-01") + | NUV ON (R.pageURL = NUV.destURL) + |GROUP BY sourceIP + |ORDER BY totalRevenue DESC + |LIMIT 1 + """.stripMargin) + + createQueryTest("query4", + """ + |DROP TABLE IF EXISTS url_counts_partial; + |CREATE TABLE url_counts_partial AS + | SELECT TRANSFORM (line) + | USING 'python target/url_count.py' as (sourcePage, + | destPage, count) from documents; + |DROP TABLE IF EXISTS url_counts_total; + |CREATE TABLE url_counts_total AS + | SELECT SUM(count) AS totalCount, destpage + | FROM url_counts_partial GROUP BY destpage + |-- The following queries run, but generate different results in HIVE likely because the UDF is not deterministic + |-- given different input splits. 
+ |-- SELECT CAST(SUM(count) AS INT) FROM url_counts_partial + |-- SELECT COUNT(*) FROM url_counts_partial + |-- SELECT * FROM url_counts_partial + |-- SELECT * FROM url_counts_total + """.stripMargin) + } +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala new file mode 100644 index 000000000..a12ab2394 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark +package sql +package hive +package execution + + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { + ignore("multiple instances not supported") { + test("Multiple Hive Instances") { + (1 to 10).map { i => + val ts = + new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", new SparkConf())) + ts.executeSql("SHOW TABLES").toRdd.collect() + ts.executeSql("SELECT * FROM src").toRdd.collect() + ts.executeSql("SHOW TABLES").toRdd.collect() + } + } + } +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala new file mode 100644 index 000000000..8a5b97b7a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -0,0 +1,379 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql +package hive +package execution + +import java.io._ +import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} + +import catalyst.plans.logical.{ExplainCommand, NativeCommand} +import catalyst.plans._ +import catalyst.util._ + +import org.apache.spark.sql.execution.Sort + +/** + * Allows the creations of tests that execute the same query against both hive + * and catalyst, comparing the results. + * + * The "golden" results from Hive are cached in an retrieved both from the classpath and + * [[answerCache]] to speed up testing. + * + * See the documentation of public vals in this class for information on how test execution can be + * configured using system properties. + */ +abstract class HiveComparisonTest extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + + /** + * When set, any cache files that result in test failures will be deleted. Used when the test + * harness or hive have been updated thus requiring new golden answers to be computed for some + * tests. Also prevents the classpath being used when looking for golden answers as these are + * usually stale. + */ + val recomputeCache = System.getProperty("spark.hive.recomputeCache") != null + + protected val shardRegEx = "(\\d+):(\\d+)".r + /** + * Allows multiple JVMs to be run in parallel, each responsible for portion of all test cases. + * Format `shardId:numShards`. Shard ids should be zero indexed. E.g. -Dspark.hive.testshard=0:4. + */ + val shardInfo = Option(System.getProperty("spark.hive.shard")).map { + case shardRegEx(id, total) => (id.toInt, total.toInt) + } + + protected val targetDir = new File("target") + + /** + * When set, this comma separated list is defines directories that contain the names of test cases + * that should be skipped. + * + * For example when `-Dspark.hive.skiptests=passed,hiveFailed` is specified and test cases listed + * in [[passedDirectory]] or [[hiveFailedDirectory]] will be skipped. 
+ */ + val skipDirectories = + Option(System.getProperty("spark.hive.skiptests")) + .toSeq + .flatMap(_.split(",")) + .map(name => new File(targetDir, s"$suiteName.$name")) + + val runOnlyDirectories = + Option(System.getProperty("spark.hive.runonlytests")) + .toSeq + .flatMap(_.split(",")) + .map(name => new File(targetDir, s"$suiteName.$name")) + + /** The local directory with cached golden answer will be stored. */ + protected val answerCache = new File("src/test/resources/golden") + if (!answerCache.exists) { + answerCache.mkdir() + } + + /** The [[ClassLoader]] that contains test dependencies. Used to look for golden answers. */ + protected val testClassLoader = this.getClass.getClassLoader + + /** Directory containing a file for each test case that passes. */ + val passedDirectory = new File(targetDir, s"$suiteName.passed") + if (!passedDirectory.exists()) { + passedDirectory.mkdir() // Not atomic! + } + + /** Directory containing output of tests that fail to execute with Catalyst. */ + val failedDirectory = new File(targetDir, s"$suiteName.failed") + if (!failedDirectory.exists()) { + failedDirectory.mkdir() // Not atomic! + } + + /** Directory containing output of tests where catalyst produces the wrong answer. */ + val wrongDirectory = new File(targetDir, s"$suiteName.wrong") + if (!wrongDirectory.exists()) { + wrongDirectory.mkdir() // Not atomic! + } + + /** Directory containing output of tests where we fail to generate golden output with Hive. */ + val hiveFailedDirectory = new File(targetDir, s"$suiteName.hiveFailed") + if (!hiveFailedDirectory.exists()) { + hiveFailedDirectory.mkdir() // Not atomic! 
+ } + + /** All directories that contain per-query output files */ + val outputDirectories = Seq( + passedDirectory, + failedDirectory, + wrongDirectory, + hiveFailedDirectory) + + protected val cacheDigest = java.security.MessageDigest.getInstance("MD5") + protected def getMd5(str: String): String = { + val digest = java.security.MessageDigest.getInstance("MD5") + digest.update(str.getBytes) + new java.math.BigInteger(1, digest.digest).toString(16) + } + + protected def prepareAnswer( + hiveQuery: TestHive.type#SqlQueryExecution, + answer: Seq[String]): Seq[String] = { + val orderedAnswer = hiveQuery.logical match { + // Clean out non-deterministic time schema info. + case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") + case _: ExplainCommand => answer + case _ => + // TODO: Really we only care about the final total ordering here... + val isOrdered = hiveQuery.executedPlan.collect { + case s @ Sort(_, global, _) if global => s + }.nonEmpty + // If the query results aren't sorted, then sort them to ensure deterministic answers. + if (!isOrdered) answer.sorted else answer + } + orderedAnswer.map(cleanPaths) + } + + // TODO: Instead of filtering we should clean to avoid accidentally ignoring actual results. + lazy val nonDeterministicLineIndicators = Seq( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_time", + "Owner:", + // The following are hive specific schema parameters which we do not need to match exactly. + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize" + ) + protected def nonDeterministicLine(line: String) = + nonDeterministicLineIndicators.map(line contains _).reduceLeft(_||_) + + /** + * Removes non-deterministic paths from `str` so cached answers will compare correctly. 
+ */ + protected def cleanPaths(str: String): String = { + str.replaceAll("file:\\/.*\\/", "") + } + + val installHooksCommand = "(?i)SET.*hooks".r + def createQueryTest(testCaseName: String, sql: String) { + // If test sharding is enable, skip tests that are not in the correct shard. + shardInfo.foreach { + case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return + case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + } + + // Skip tests found in directories specified by user. + skipDirectories + .map(new File(_, testCaseName)) + .filter(_.exists) + .foreach(_ => return) + + // If runonlytests is set, skip this test unless we find a file in one of the specified + // directories. + val runIndicators = + runOnlyDirectories + .map(new File(_, testCaseName)) + .filter(_.exists) + if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { + logger.debug( + s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") + return + } + + test(testCaseName) { + logger.debug(s"=== HIVE TEST: $testCaseName ===") + + // Clear old output for this testcase. + outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) + + val allQueries = sql.split("(?<=[^\\\\]);").map(_.trim).filterNot(q => q == "").toSeq + + // TODO: DOCUMENT UNSUPPORTED + val queryList = + allQueries + // In hive, setting the hive.outerjoin.supports.filters flag to "false" essentially tells + // the system to return the wrong answer. Since we have no intention of mirroring their + // previously broken behavior we simply filter out changes to this setting. 
+ .filterNot(_ contains "hive.outerjoin.supports.filters") + + if (allQueries != queryList) + logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + + lazy val consoleTestCase = { + val quotes = "\"\"\"" + queryList.zipWithIndex.map { + case (query, i) => + s""" + |val q$i = $quotes$query$quotes.q + |q$i.stringResult() + """.stripMargin + }.mkString("\n== Console version of this test ==\n", "\n", "\n") + } + + try { + // MINOR HACK: You must run a query before calling reset the first time. + TestHive.sql("SHOW TABLES") + TestHive.reset() + + val hiveCacheFiles = queryList.zipWithIndex.map { + case (queryString, i) => + val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" + new File(answerCache, cachedAnswerName) + } + + val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => + logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + if (cachedAnswerFile.exists) { + Some(fileToString(cachedAnswerFile)) + } else { + logger.debug(s"File $cachedAnswerFile not found") + None + } + }.map { + case "" => Nil + case "\n" => Seq("") + case other => other.split("\n").toSeq + } + + val hiveResults: Seq[Seq[String]] = + if (hiveCachedResults.size == queryList.size) { + logger.info(s"Using answer cache for test: $testCaseName") + hiveCachedResults + } else { + + val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_)) + // Make sure we can at least parse everything before attempting hive execution. + hiveQueries.foreach(_.logical) + val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { + case ((queryString, i), hiveQuery, cachedAnswerFile)=> + try { + // Hooks often break the harness and don't really affect our test anyway, don't + // even try running them. 
+ if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) + sys.error("hive exec hooks not supported for tests.") + + logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + // Analyze the query with catalyst to ensure test tables are loaded. + val answer = hiveQuery.analyzed match { + case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. + case _ => TestHive.runSqlHive(queryString) + } + + // We need to add a new line to non-empty answers so we can differentiate Seq() + // from Seq(""). + stringToFile( + cachedAnswerFile, answer.mkString("\n") + (if (answer.nonEmpty) "\n" else "")) + answer + } catch { + case e: Exception => + val errorMessage = + s""" + |Failed to generate golden answer for query: + |Error: ${e.getMessage} + |${stackTraceToString(e)} + |$queryString + |$consoleTestCase + """.stripMargin + stringToFile( + new File(hiveFailedDirectory, testCaseName), + errorMessage + consoleTestCase) + fail(errorMessage) + } + }.toSeq + TestHive.reset() + + computedResults + } + + // Run w/ catalyst + val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => + val query = new TestHive.SqlQueryExecution(queryString) + try { (query, prepareAnswer(query, query.stringResult())) } catch { + case e: Exception => + val errorMessage = + s""" + |Failed to execute query using catalyst: + |Error: ${e.getMessage} + |${stackTraceToString(e)} + |$query + |== HIVE - ${hive.size} row(s) == + |${hive.mkString("\n")} + | + |$consoleTestCase + """.stripMargin + stringToFile(new File(failedDirectory, testCaseName), errorMessage + consoleTestCase) + fail(errorMessage) + } + }.toSeq + + (queryList, hiveResults, catalystResults).zipped.foreach { + case (query, hive, (hiveQuery, catalyst)) => + // Check that the results match unless its an EXPLAIN query. 
+ val preparedHive = prepareAnswer(hiveQuery,hive) + + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + + val hivePrintOut = s"== HIVE - ${hive.size} row(s) ==" +: preparedHive + val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst + + val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") + + if (recomputeCache) { + logger.warn(s"Clearing cache files for failed test $testCaseName") + hiveCacheFiles.foreach(_.delete()) + } + + val errorMessage = + s""" + |Results do not match for $testCaseName: + |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} + |$resultComparison + """.stripMargin + + stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) + fail(errorMessage) + } + } + + // Touch passed file. + new FileOutputStream(new File(passedDirectory, testCaseName)).close() + } catch { + case tf: org.scalatest.exceptions.TestFailedException => throw tf + case originalException: Exception => + if (System.getProperty("spark.hive.canarytest") != null) { + // When we encounter an error we check to see if the environment is still okay by running a simple query. + // If this fails then we halt testing since something must have gone seriously wrong. + try { + new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult() + TestHive.runSqlHive("SELECT key FROM src") + } catch { + case e: Exception => + logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + // The testing setup traps exits so wait here for a long time so the developer can see when things started + // to go wrong. + Thread.sleep(1000000) + } + } + + // If the canary query didn't fail then the environment is still okay, so just throw the original exception. 
+ throw originalException + } + } + } +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala new file mode 100644 index 000000000..d010023f7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -0,0 +1,708 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + + +import java.io._ + +import util._ + +/** + * Runs the test cases that are included in the hive distribution. + */ +class HiveCompatibilitySuite extends HiveQueryFileTest { + // TODO: bundle in jar files... get from classpath + lazy val hiveQueryDir = TestHive.getHiveFile("ql/src/test/queries/clientpositive") + def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + + /** A list of tests deemed out of scope currently and thus completely disregarded. */ + override def blackList = Seq( + // These tests use hooks that are not on the classpath and thus break all subsequent execution. 
+ "hook_order", + "hook_context", + "mapjoin_hook", + "multi_sahooks", + "overridden_confs", + "query_properties", + "sample10", + "updateAccessTime", + "index_compact_binary_search", + "bucket_num_reducers", + "column_access_stats", + "concatenate_inherit_table_location", + + // Setting a default property does not seem to get reset and thus changes the answer for many + // subsequent tests. + "create_default_prop", + + // User/machine specific test answers, breaks the caching mechanism. + "authorization_3", + "authorization_5", + "keyword_1", + "misc_json", + "create_like_tbl_props", + "load_overwrite", + "alter_table_serde2", + "alter_table_not_sorted", + "alter_skewed_table", + "alter_partition_clusterby_sortby", + "alter_merge", + "alter_concatenate_indexed_table", + "protectmode2", + "describe_table", + "describe_comment_nonascii", + "udf5", + "udf_java_method", + + // Weird DDL differences result in failures on jenkins. + "create_like2", + "create_view_translate", + "partitions_json", + + // Timezone specific test answers. + "udf_unix_timestamp", + "udf_to_unix_timestamp", + + // Cant run without local map/reduce. + "index_auto_update", + "index_auto_self_join", + "index_stale.*", + "type_cast_1", + "index_compression", + "index_bitmap_compression", + "index_auto_multiple", + "index_auto_mult_tables_compact", + "index_auto_mult_tables", + "index_auto_file_format", + "index_auth", + "index_auto_empty", + "index_auto_partitioned", + "index_auto_unused", + "index_bitmap_auto_partitioned", + "ql_rewrite_gbtoidx", + "stats1.*", + "stats20", + "alter_merge_stats", + + // Hive seems to think 1.0 > NaN = true && 1.0 < NaN = false... which is wrong. + // http://stackoverflow.com/a/1573715 + "ops_comparison", + + // Tests that seems to never complete on hive... + "skewjoin", + "database", + + // These tests fail and and exit the JVM. 
+ "auto_join18_multi_distinct", + "join18_multi_distinct", + "input44", + "input42", + "input_dfs", + "metadata_export_drop", + "repair", + + // Uses a serde that isn't on the classpath... breaks other tests. + "bucketizedhiveinputformat", + + // Avro tests seem to change the output format permanently thus breaking the answer cache, until + // we figure out why this is the case let just ignore all of avro related tests. + ".*avro.*", + + // Unique joins are weird and will require a lot of hacks (see comments in hive parser). + "uniquejoin", + + // Hive seems to get the wrong answer on some outer joins. MySQL agrees with catalyst. + "auto_join29", + + // No support for multi-alias i.e. udf as (e1, e2, e3). + "allcolref_in_udf", + + // No support for TestSerDe (not published afaik) + "alter1", + "input16", + + // No support for unpublished test udfs. + "autogen_colalias", + + // Hive does not support buckets. + ".*bucket.*", + + // No window support yet + ".*window.*", + + // Fails in hive with authorization errors. + "alter_rename_partition_authorization", + "authorization.*", + + // Hadoop version specific tests + "archive_corrupt", + + // No support for case sensitivity is resolution using hive properties atm. + "case_sensitivity" + ) + + /** + * The set of tests that are believed to be working in catalyst. Tests not on whiteList or + * blacklist are implicitly marked as ignored. 
+ */ + override def whiteList = Seq( + "add_part_exist", + "add_partition_no_whitelist", + "add_partition_with_whitelist", + "alias_casted_column", + "alter2", + "alter4", + "alter5", + "alter_index", + "alter_merge_2", + "alter_partition_format_loc", + "alter_partition_protect_mode", + "alter_partition_with_whitelist", + "alter_table_serde", + "alter_varchar2", + "alter_view_as_select", + "ambiguous_col", + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_15", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "binary_constant", + "binarysortable_1", + "combine1", + "compute_stats_binary", + "compute_stats_boolean", + "compute_stats_double", + "compute_stats_table", + "compute_stats_long", + "compute_stats_string", + "convert_enum_to_string", + "correlationoptimizer11", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "count", + "create_like_view", + "create_nested_type", + "create_skewed_table1", + "create_struct_table", + "ct_case_insensitive", + "database_location", + 
"database_properties", + "decimal_join", + "default_partition_name", + "delimiter", + "desc_non_existent_tbl", + "describe_comment_indent", + "describe_database_json", + "describe_pretty", + "describe_syntax", + "describe_table_json", + "diff_part_input_formats", + "disable_file_format_check", + "drop_function", + "drop_index", + "drop_partitions_filter", + "drop_partitions_filter2", + "drop_partitions_filter3", + "drop_partitions_ignore_protection", + "drop_table", + "drop_table2", + "drop_view", + "escape_clusterby1", + "escape_distributeby1", + "escape_orderby1", + "escape_sortby1", + "fetch_aggregation", + "filter_join_breaktask", + "filter_join_breaktask2", + "groupby1", + "groupby11", + "groupby1_map", + "groupby1_map_nomap", + "groupby1_map_skew", + "groupby1_noskew", + "groupby4", + "groupby4_map", + "groupby4_map_skew", + "groupby4_noskew", + "groupby5", + "groupby5_map", + "groupby5_map_skew", + "groupby5_noskew", + "groupby6", + "groupby6_map", + "groupby6_map_skew", + "groupby6_noskew", + "groupby7", + "groupby7_map", + "groupby7_map_multi_single_reducer", + "groupby7_map_skew", + "groupby7_noskew", + "groupby8_map", + "groupby8_map_skew", + "groupby8_noskew", + "groupby_distinct_samekey", + "groupby_multi_single_reducer2", + "groupby_mutli_insert_common_distinct", + "groupby_neg_float", + "groupby_sort_10", + "groupby_sort_6", + "groupby_sort_8", + "groupby_sort_test_1", + "implicit_cast1", + "innerjoin", + "inoutdriver", + "input", + "input0", + "input11", + "input11_limit", + "input12", + "input12_hadoop20", + "input19", + "input1_limit", + "input22", + "input23", + "input24", + "input25", + "input26", + "input28", + "input2_limit", + "input40", + "input41", + "input4_cb_delim", + "input6", + "input7", + "input8", + "input9", + "input_limit", + "input_part0", + "input_part1", + "input_part10", + "input_part10_win", + "input_part2", + "input_part3", + "input_part4", + "input_part5", + "input_part6", + "input_part7", + "input_part8", + "input_part9", + 
"inputddl4", + "inputddl7", + "inputddl8", + "insert_compressed", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_nulls", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star", + "join_view", + "lateral_view_cp", + "lateral_view_ppd", + "lineage1", + "literal_double", + "literal_ints", + "literal_string", + "load_dyn_part7", + "load_file_with_space_in_the_name", + "louter_join_ppr", + "mapjoin_distinct", + "mapjoin_mapjoin", + "mapjoin_subquery", + "mapjoin_subquery2", + "mapjoin_test_outer", + "mapreduce3", + "mapreduce7", + "merge1", + "merge2", + "mergejoins", + "mergejoins_mixed", + "multiMapJoin1", + "multiMapJoin2", + "multi_join_union", + "multigroupby_singlemr", + "noalias_subq1", + "nomore_ambiguous_table_col", + "nonblock_op_deduplicate", + "notable_alias1", + "notable_alias2", + "nullgroup", + "nullgroup2", + "nullgroup3", + "nullgroup4", + "nullgroup4_multi_distinct", + "nullgroup5", + "nullinput", + "nullinput2", + "nullscript", + "optional_outer", + "order", + "order2", + "outer_join_ppr", + "part_inherit_tbl_props", + "part_inherit_tbl_props_empty", + "part_inherit_tbl_props_with_star", + "partition_schema1", + "partition_varchar1", + "plan_json", + "ppd1", + "ppd_constant_where", + "ppd_gby", + "ppd_gby2", + "ppd_gby_join", + "ppd_join", + "ppd_join2", + "ppd_join3", + "ppd_join_filter", + "ppd_outer_join1", + "ppd_outer_join2", + "ppd_outer_join3", + 
"ppd_outer_join4", + "ppd_outer_join5", + "ppd_random", + "ppd_repeated_alias", + "ppd_udf_col", + "ppd_union", + "ppr_allchildsarenull", + "ppr_pushdown", + "ppr_pushdown2", + "ppr_pushdown3", + "progress_1", + "protectmode", + "push_or", + "query_with_semi", + "quote1", + "quote2", + "reduce_deduplicate_exclude_join", + "rename_column", + "router_join_ppr", + "select_as_omitted", + "select_unquote_and", + "select_unquote_not", + "select_unquote_or", + "serde_reported_schema", + "set_variable_sub", + "show_describe_func_quotes", + "show_functions", + "show_partitions", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + "smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_8", + "sort", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + "sort_merge_join_desc_5", + "sort_merge_join_desc_6", + "sort_merge_join_desc_7", + "stats0", + "stats_empty_partition", + "subq2", + "tablename_with_select", + "touch", + "type_widening", + "udaf_collect_set", + "udaf_corr", + "udaf_covar_pop", + "udaf_covar_samp", + "udf2", + "udf6", + "udf9", + "udf_10_trims", + "udf_E", + "udf_PI", + "udf_abs", + "udf_acos", + "udf_add", + "udf_array", + "udf_array_contains", + "udf_ascii", + "udf_asin", + "udf_atan", + "udf_avg", + "udf_bigint", + "udf_bin", + "udf_bitmap_and", + "udf_bitmap_empty", + "udf_bitmap_or", + "udf_bitwise_and", + "udf_bitwise_not", + "udf_bitwise_or", + "udf_bitwise_xor", + "udf_boolean", + "udf_case", + "udf_ceil", + "udf_ceiling", + "udf_concat", + "udf_concat_insert2", + "udf_concat_ws", + "udf_conv", + "udf_cos", + "udf_count", + "udf_date_add", + "udf_date_sub", + "udf_datediff", + "udf_day", + "udf_dayofmonth", + "udf_degrees", + "udf_div", + "udf_double", + "udf_exp", + 
"udf_field", + "udf_find_in_set", + "udf_float", + "udf_floor", + "udf_format_number", + "udf_from_unixtime", + "udf_greaterthan", + "udf_greaterthanorequal", + "udf_hex", + "udf_if", + "udf_index", + "udf_int", + "udf_isnotnull", + "udf_isnull", + "udf_java_method", + "udf_lcase", + "udf_length", + "udf_lessthan", + "udf_lessthanorequal", + "udf_like", + "udf_ln", + "udf_log", + "udf_log10", + "udf_log2", + "udf_lower", + "udf_lpad", + "udf_ltrim", + "udf_map", + "udf_minute", + "udf_modulo", + "udf_month", + "udf_negative", + "udf_not", + "udf_notequal", + "udf_notop", + "udf_nvl", + "udf_or", + "udf_parse_url", + "udf_positive", + "udf_pow", + "udf_power", + "udf_radians", + "udf_rand", + "udf_regexp", + "udf_regexp_extract", + "udf_regexp_replace", + "udf_repeat", + "udf_rlike", + "udf_round", + "udf_round_3", + "udf_rpad", + "udf_rtrim", + "udf_second", + "udf_sign", + "udf_sin", + "udf_smallint", + "udf_space", + "udf_sqrt", + "udf_std", + "udf_stddev", + "udf_stddev_pop", + "udf_stddev_samp", + "udf_string", + "udf_substring", + "udf_subtract", + "udf_sum", + "udf_tan", + "udf_tinyint", + "udf_to_byte", + "udf_to_date", + "udf_to_double", + "udf_to_float", + "udf_to_long", + "udf_to_short", + "udf_translate", + "udf_trim", + "udf_ucase", + "udf_upper", + "udf_var_pop", + "udf_var_samp", + "udf_variance", + "udf_weekofyear", + "udf_when", + "udf_xpath", + "udf_xpath_boolean", + "udf_xpath_double", + "udf_xpath_float", + "udf_xpath_int", + "udf_xpath_long", + "udf_xpath_short", + "udf_xpath_string", + "unicode_notation", + "union10", + "union11", + "union13", + "union14", + "union15", + "union16", + "union17", + "union18", + "union19", + "union2", + "union20", + "union22", + "union23", + "union24", + "union26", + "union27", + "union28", + "union29", + "union30", + "union31", + "union34", + "union4", + "union5", + "union6", + "union7", + "union8", + "union9", + "union_lateralview", + "union_ppr", + "union_remove_3", + "union_remove_6", + "union_script", + 
"varchar_2", + "varchar_join1", + "varchar_union1" + ) +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala new file mode 100644 index 000000000..f0a4ec3c0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +import java.io._ + +import catalyst.util._ + +/** + * A framework for running the query tests that are listed as a set of text files. + * + * TestSuites that derive from this class must provide a map of testCaseName -> testCaseFiles that should be included. + * Additionally, there is support for whitelisting and blacklisting tests as development progresses. + */ +abstract class HiveQueryFileTest extends HiveComparisonTest { + /** A list of tests deemed out of scope and thus completely disregarded */ + def blackList: Seq[String] = Nil + + /** + * The set of tests that are believed to be working in catalyst. Tests not in whiteList + * blacklist are implicitly marked as ignored. 
+ */
+  def whiteList: Seq[String] = ".*" :: Nil
+
+  def testCases: Seq[(String, File)]
+
+  val runAll =
+    System.getProperty("spark.hive.alltests") != null ||
+    runOnlyDirectories.nonEmpty ||
+    skipDirectories.nonEmpty
+
+  val whiteListProperty = "spark.hive.whitelist"
+  // Allow the whiteList to be overridden by a system property
+  val realWhiteList =
+    Option(System.getProperty(whiteListProperty)).map(_.split(",").toSeq).getOrElse(whiteList)
+
+  // Go through all the test cases and add them to scala test.
+  testCases.sorted.foreach {
+    case (testCaseName, testCaseFile) =>
+      if (blackList.exists(_.r.pattern.matcher(testCaseName).matches())) {
+        logger.debug(s"Blacklisted test skipped $testCaseName")
+      } else if (realWhiteList.exists(_.r.pattern.matcher(testCaseName).matches()) || runAll) {
+        // Build a test case and submit it to scala test framework...
+        val queriesString = fileToString(testCaseFile)
+        createQueryTest(testCaseName, queriesString)
+      } else {
+        // Only output warnings for the built in whitelist as this clutters the output when the user
+        // is trying to execute a single test from the command line.
+        if (System.getProperty(whiteListProperty) == null && !runAll)
+          ignore(testCaseName) {}
+      }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
new file mode 100644
index 000000000..28a5d260b
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + + +/** + * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. + */ +class HiveQuerySuite extends HiveComparisonTest { + import TestHive._ + + createQueryTest("Simple Average", + "SELECT AVG(key) FROM src") + + createQueryTest("Simple Average + 1", + "SELECT AVG(key) + 1.0 FROM src") + + createQueryTest("Simple Average + 1 with group", + "SELECT AVG(key) + 1.0, value FROM src group by value") + + createQueryTest("string literal", + "SELECT 'test' FROM src") + + createQueryTest("Escape sequences", + """SELECT key, '\\\t\\' FROM src WHERE key = 86""") + + createQueryTest("IgnoreExplain", + """EXPLAIN SELECT key FROM src""") + + createQueryTest("trivial join where clause", + "SELECT * FROM src a JOIN src b WHERE a.key = b.key") + + createQueryTest("trivial join ON clause", + "SELECT * FROM src a JOIN src b ON a.key = b.key") + + createQueryTest("small.cartesian", + "SELECT a.key, b.key FROM (SELECT key FROM src WHERE key < 1) a JOIN (SELECT key FROM src WHERE key = 2) b") + + createQueryTest("length.udf", + "SELECT length(\"test\") FROM src LIMIT 1") + + ignore("partitioned table scan") { + createQueryTest("partitioned table scan", + "SELECT ds, hr, key, value FROM srcpart") + } + + createQueryTest("hash", + "SELECT hash('test') FROM src LIMIT 1") + + 
createQueryTest("create table as", + """ + |CREATE TABLE createdtable AS SELECT * FROM src; + |SELECT * FROM createdtable + """.stripMargin) + + createQueryTest("create table as with db name", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE testdb.createdtable AS SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + + createQueryTest("insert table with db name", + """ + |CREATE DATABASE IF NOT EXISTS testdb; + |CREATE TABLE testdb.createdtable like default.src; + |INSERT INTO TABLE testdb.createdtable SELECT * FROM default.src; + |SELECT * FROM testdb.createdtable; + |DROP DATABASE IF EXISTS testdb CASCADE + """.stripMargin) + + createQueryTest("insert into and insert overwrite", + """ + |CREATE TABLE createdtable like src; + |INSERT INTO TABLE createdtable SELECT * FROM src; + |INSERT INTO TABLE createdtable SELECT * FROM src1; + |SELECT * FROM createdtable; + |INSERT OVERWRITE TABLE createdtable SELECT * FROM src WHERE key = 86; + |SELECT * FROM createdtable; + """.stripMargin) + + createQueryTest("transform", + "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") + + createQueryTest("LIKE", + "SELECT * FROM src WHERE value LIKE '%1%'") + + createQueryTest("DISTINCT", + "SELECT DISTINCT key, value FROM src") + + ignore("empty aggregate input") { + createQueryTest("empty aggregate input", + "SELECT SUM(key) FROM (SELECT * FROM src LIMIT 0) a") + } + + createQueryTest("lateral view1", + "SELECT tbl.* FROM src LATERAL VIEW explode(array(1,2)) tbl as a") + + createQueryTest("lateral view2", + "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") + + + createQueryTest("lateral view3", + "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") + + createQueryTest("lateral view4", + """ + |create table src_lv1 (key string, value string); + |create table src_lv2 (key string, value string); + | + |FROM src + |insert overwrite table src_lv1 SELECT key, D.* 
lateral view explode(array(key+3, key+4)) D as CX + |insert overwrite table src_lv2 SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX + """.stripMargin) + + createQueryTest("lateral view5", + "FROM src SELECT explode(array(key+3, key+4))") + + createQueryTest("lateral view6", + "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + + test("sampling") { + sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + } + +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala new file mode 100644 index 000000000..0dd79faa1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +/** + * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. 
+ */
+class HiveResolutionSuite extends HiveComparisonTest {
+  import TestHive._
+
+  createQueryTest("table.attr",
+    "SELECT src.key FROM src ORDER BY key LIMIT 1")
+
+  createQueryTest("database.table",
+    "SELECT key FROM default.src ORDER BY key LIMIT 1")
+
+  createQueryTest("database.table table.attr",
+    "SELECT src.key FROM default.src ORDER BY key LIMIT 1")
+
+  createQueryTest("alias.attr",
+    "SELECT a.key FROM src a ORDER BY key LIMIT 1")
+
+  createQueryTest("subquery-alias.attr",
+    "SELECT a.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a")
+
+  createQueryTest("quoted alias.attr",
+    "SELECT `a`.`key` FROM src a ORDER BY key LIMIT 1")
+
+  createQueryTest("attr",
+    "SELECT key FROM src a ORDER BY key LIMIT 1")
+
+  createQueryTest("alias.*",
+    "SELECT a.* FROM src a ORDER BY key LIMIT 1")
+
+  /**
+   * Negative examples.  Currently only left here for documentation purposes.
+   * TODO(marmbrus): Test that catalyst fails on these queries.
+   */
+
+  /* SemanticException [Error 10009]: Line 1:7 Invalid table alias 'src'
+  createQueryTest("table.*",
+    "SELECT src.* FROM src a ORDER BY key LIMIT 1") */
+
+  /* Invalid table alias or column reference 'src': (possible column names are: key, value)
+  createQueryTest("tableName.attr from aliased subquery",
+    "SELECT src.key FROM (SELECT * FROM src ORDER BY key LIMIT 1) a") */
+
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
new file mode 100644
index 000000000..c2264926f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql +package hive +package execution + +/** + * A set of tests that validates support for Hive SerDe. + */ +class HiveSerDeSuite extends HiveComparisonTest { + createQueryTest( + "Read and write with LazySimpleSerDe (tab separated)", + "SELECT * from serdeins") + + createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") + + createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala new file mode 100644 index 000000000..bb33583e5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+/**
+ * A set of tests that validate type promotion rules.
+ */
+class HiveTypeCoercionSuite extends HiveComparisonTest {
+
+  val baseTypes = Seq("1", "1.0", "1L", "1S", "1Y", "'1'")
+
+  baseTypes.foreach { i =>
+    baseTypes.foreach { j =>
+      createQueryTest(s"$i + $j", s"SELECT $i + $j FROM src LIMIT 1")
+    }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
new file mode 100644
index 000000000..8542f42aa
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql +package hive +package execution + +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.hive.TestHive + +/** + * A set of test cases that validate partition and column pruning. + */ +class PruningSuite extends HiveComparisonTest { + // Column pruning tests + + createPruningTest("Column pruning: with partitioned table", + "SELECT key FROM srcpart WHERE ds = '2008-04-08' LIMIT 3", + Seq("key"), + Seq("key", "ds"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-08", "12"))) + + createPruningTest("Column pruning: with non-partitioned table", + "SELECT key FROM src WHERE key > 10 LIMIT 3", + Seq("key"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: with multiple projects", + "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("c1"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: projects alias substituting", + "SELECT c1 AS c2 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("c2"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: filter alias in-lining", + "SELECT c1 FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 WHERE c1 < 100 LIMIT 3", + Seq("c1"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: without filters", + "SELECT c1 FROM (SELECT key AS c1 FROM src) t1 LIMIT 3", + Seq("c1"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: simple top project without aliases", + "SELECT key FROM (SELECT key FROM src WHERE key > 10) t1 WHERE key < 100 LIMIT 3", + Seq("key"), + Seq("key"), + Seq.empty) + + createPruningTest("Column pruning: non-trivial top project with aliases", + "SELECT c1 * 2 AS double FROM (SELECT key AS c1 FROM src WHERE key > 10) t1 LIMIT 3", + Seq("double"), + Seq("key"), + Seq.empty) + + // Partition pruning tests + + createPruningTest("Partition pruning: non-partitioned, non-trivial project", + "SELECT key * 2 AS double FROM src WHERE value IS NOT NULL", + 
Seq("double"), + Seq("key", "value"), + Seq.empty) + + createPruningTest("Partiton pruning: non-partitioned table", + "SELECT value FROM src WHERE key IS NOT NULL", + Seq("value"), + Seq("value", "key"), + Seq.empty) + + createPruningTest("Partition pruning: with filter on string partition key", + "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08'", + Seq("value", "hr"), + Seq("value", "hr", "ds"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-08", "12"))) + + createPruningTest("Partition pruning: with filter on int partition key", + "SELECT value, hr FROM srcpart1 WHERE hr < 12", + Seq("value", "hr"), + Seq("value", "hr"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + createPruningTest("Partition pruning: left only 1 partition", + "SELECT value, hr FROM srcpart1 WHERE ds = '2008-04-08' AND hr < 12", + Seq("value", "hr"), + Seq("value", "hr", "ds"), + Seq( + Seq("2008-04-08", "11"))) + + createPruningTest("Partition pruning: all partitions pruned", + "SELECT value, hr FROM srcpart1 WHERE ds = '2014-01-27' AND hr = 11", + Seq("value", "hr"), + Seq("value", "hr", "ds"), + Seq.empty) + + createPruningTest("Partition pruning: pruning with both column key and partition key", + "SELECT value, hr FROM srcpart1 WHERE value IS NOT NULL AND hr < 12", + Seq("value", "hr"), + Seq("value", "hr"), + Seq( + Seq("2008-04-08", "11"), + Seq("2008-04-09", "11"))) + + def createPruningTest( + testCaseName: String, + sql: String, + expectedOutputColumns: Seq[String], + expectedScannedColumns: Seq[String], + expectedPartValues: Seq[Seq[String]]) = { + test(s"$testCaseName - pruning test") { + val plan = new TestHive.SqlQueryExecution(sql).executedPlan + val actualOutputColumns = plan.output.map(_.name) + val (actualScannedColumns, actualPartValues) = plan.collect { + case p @ HiveTableScan(columns, relation, _) => + val columnNames = columns.map(_.name) + val partValues = p.prunePartitions(relation.hiveQlPartitions).map(_.getValues) + (columnNames, 
partValues) + }.head + + assert(actualOutputColumns sameElements expectedOutputColumns, "Output columns mismatch") + assert(actualScannedColumns sameElements expectedScannedColumns, "Scanned columns mismatch") + + assert( + actualPartValues.length === expectedPartValues.length, + "Partition value count mismatches") + + for ((actual, expected) <- actualPartValues.zip(expectedPartValues)) { + assert(actual sameElements expected, "Partition values mismatch") + } + } + + // Creates a query test to compare query results generated by Hive and Catalyst. + createQueryTest(s"$testCaseName - query test", sql) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala new file mode 100644 index 000000000..ee90061c7 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import java.io.File + +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.hive.TestHive + + +class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + + val filename = getTempFilePath("parquettest").getCanonicalFile.toURI.toString + + // runs a SQL and optionally resolves one Parquet table + def runQuery(querystr: String, tableName: Option[String] = None, filename: Option[String] = None): Array[Row] = { + // call to resolve references in order to get CREATE TABLE AS to work + val query = TestHive + .parseSql(querystr) + val finalQuery = + if (tableName.nonEmpty && filename.nonEmpty) + resolveParquetTable(tableName.get, filename.get, query) + else + query + TestHive.executePlan(finalQuery) + .toRdd + .collect() + } + + // stores a query output to a Parquet file + def storeQuery(querystr: String, filename: String): Unit = { + val query = WriteToFile( + filename, + TestHive.parseSql(querystr)) + TestHive + .executePlan(query) + .stringResult() + } + + /** + * TODO: This function is necessary as long as there is no notion of a Catalog for + * Parquet tables. Once such a thing exists this functionality should be moved there. 
+ */ + def resolveParquetTable(tableName: String, filename: String, plan: LogicalPlan): LogicalPlan = { + TestHive.loadTestTable("src") // may not be loaded now + plan.transform { + case relation @ UnresolvedRelation(databaseName, name, alias) => + if (name == tableName) + ParquetRelation(tableName, filename) + else + relation + case op @ InsertIntoCreatedTable(databaseName, name, child) => + if (name == tableName) { + // note: at this stage the plan is not yet analyzed but Parquet needs to know the schema + // and for that we need the child to be resolved + val relation = ParquetRelation.create( + filename, + TestHive.analyzer(child), + TestHive.sparkContext.hadoopConfiguration, + Some(tableName)) + InsertIntoTable( + relation.asInstanceOf[BaseRelation], + Map.empty, + child, + overwrite = false) + } else + op + } + } + + override def beforeAll() { + // write test data + ParquetTestData.writeFile + // Override initial Parquet test table + TestHive.catalog.registerTable(Some[String]("parquet"), "testsource", ParquetTestData.testData) + } + + override def afterAll() { + ParquetTestData.testFile.delete() + } + + override def beforeEach() { + new File(filename).getAbsoluteFile.delete() + } + + override def afterEach() { + new File(filename).getAbsoluteFile.delete() + } + + test("SELECT on Parquet table") { + val rdd = runQuery("SELECT * FROM parquet.testsource") + assert(rdd != null) + assert(rdd.forall(_.size == 6)) + } + + test("Simple column projection + filter on Parquet table") { + val rdd = runQuery("SELECT myboolean, mylong FROM parquet.testsource WHERE myboolean=true") + assert(rdd.size === 5, "Filter returned incorrect number of rows") + assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") + } + + test("Converting Hive to Parquet Table via WriteToFile") { + storeQuery("SELECT * FROM src", filename) + val rddOne = runQuery("SELECT * FROM src").sortBy(_.getInt(0)) + val rddTwo = runQuery("SELECT * from ptable", Some("ptable"), 
Some(filename)).sortBy(_.getInt(0)) + compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) + } + + test("INSERT OVERWRITE TABLE Parquet table") { + storeQuery("SELECT * FROM parquet.testsource", filename) + runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename)) + runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename)) + val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename)) + val rddOrig = runQuery("SELECT * FROM parquet.testsource") + compareRDDs(rddOrig, rddCopy, "parquet.testsource", ParquetTestData.testSchemaFieldNames) + } + + test("CREATE TABLE AS Parquet table") { + runQuery("CREATE TABLE ptable AS SELECT * FROM src", Some("ptable"), Some(filename)) + val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename)) + .sortBy[Int](_.apply(0) match { + case x: Int => x + case _ => 0 + }) + val rddOrig = runQuery("SELECT * FROM src").sortBy(_.getInt(0)) + compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String")) + } + + private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) { + var counter = 0 + (rddOne, rddTwo).zipped.foreach { + (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { + case ((value_1:Array[Byte], value_2:Array[Byte]), index) => + assert(new String(value_1) === new String(value_2), s"table $tableName row ${counter} field ${fieldNames(index)} don't match") + case ((value_1, value_2), index) => + assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") + } + counter = counter + 1 + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5847b95e3..062b888e8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -201,7 +201,7 @@ class StreamingContext private[streaming] ( /** * Create an input stream with any arbitrary user implemented network receiver. - * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of NetworkReceiver */ def networkStream[T: ClassTag]( @@ -211,7 +211,7 @@ class StreamingContext private[streaming] ( /** * Create an input stream with any arbitrary user implemented actor receiver. - * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param props Props object defining creation of the actor * @param name Name of the actor * @param storageLevel RDD storage level. Defaults to memory-only. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index 9c5b177c1..bd78bae8a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -46,7 +46,7 @@ object ReceiverSupervisorStrategy { * A receiver trait to be mixed in with your Actor to gain access to * the API for pushing received data into Spark Streaming for being processed. * - * Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html + * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * * @example {{{ * class MyActor extends Actor with Receiver{