diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 207d7a352c7b3..d09b20d4beb5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -196,4 +196,17 @@ object TestData {
         :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
         :: Nil).toDF()
   complexData.registerTempTable("complexData")
+
+  /** One (group key, value) row of the `groupData` temp table. */
+  case class GroupData(g: String, v: Int)
+
+  // Three groups: red sums to 3, blue to 10 and green to 300.
+  val groupData =
+    TestSQLContext.sparkContext.parallelize(
+      GroupData("red", 1) ::
+      GroupData("red", 2) ::
+      GroupData("blue", 10) ::
+      GroupData("green", 100) ::
+      GroupData("green", 200) :: Nil).toDF()
+  groupData.registerTempTable("groupData")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 703a34c47ec20..464d125663406 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -21,6 +21,7 @@ package org.apache.spark.sql
 case class FunctionResult(f1: String, f2: String)
 
 class UDFSuite extends QueryTest {
+  import org.apache.spark.sql.TestData._
 
   private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
   import ctx.implicits._
@@ -82,6 +83,57 @@ class UDFSuite extends QueryTest {
     assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
   }
 
+  // The tests below reach the context through `testData.sqlContext`; presumably
+  // touching a TestData member also initializes that object and its temp tables.
+  test("UDF in a WHERE") {
+    testData.sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 })
+
+    // Expects 20 rows, presumably keys 81 to 100 of testData -- confirm in TestData.
+    val result =
+      testData.sqlContext.sql("SELECT * FROM testData WHERE oneArgFilter(key)")
+    assert(result.count() === 20)
+  }
+
+  test("UDF in a HAVING") {
+    testData.sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 })
+
+    // Group sums are red = 3, blue = 10 and green = 300, so two groups exceed 5.
+    val result = testData.sqlContext.sql(
+      "SELECT g, SUM(v) as s FROM groupData GROUP BY g HAVING havingFilter(s)")
+    assert(result.count() === 2)
+  }
+
+  test("UDF in a GROUP BY") {
+    testData.sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 })
+
+    // groupFunction is false for 1, 2 and 10 and true for 100 and 200: two groups.
+    val result =
+      testData.sqlContext.sql("SELECT SUM(v) FROM groupData GROUP BY groupFunction(v)")
+    assert(result.count() === 2)
+  }
+
+  test("UDFs everywhere") {
+    // Registers groupFunction and havingFilter again, the latter with a different
+    // threshold than in "UDF in a HAVING"; each test registers before it queries.
+    ctx.udf.register("groupFunction", (n: Int) => { n > 10 })
+    ctx.udf.register("havingFilter", (n: Long) => { n > 2000 })
+    ctx.udf.register("whereFilter", (n: Int) => { n < 150 })
+    ctx.udf.register("timesHundred", (n: Long) => { n * 100 })
+
+    // whereFilter keeps 1, 2, 10 and 100; the two groups sum to 13 and 100, so
+    // v100 is 1300 and 10000, and only 10000 exceeds 2000.
+    val result =
+      testData.sqlContext.sql(
+        """
+         | SELECT timesHundred(SUM(v)) as v100
+         | FROM groupData
+         | WHERE whereFilter(v)
+         | GROUP BY groupFunction(v)
+         | HAVING havingFilter(v100)
+        """.stripMargin)
+    assert(result.count() === 1)
+  }
+
   test("struct UDF") {
     ctx.udf.register("returnStruct", (f1: String, f2: String) =>
       FunctionResult(f1, f2))