From d9c42fbe7924d962b6fe018147d9f6c907b40e7b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 4 Sep 2015 15:17:37 -0700 Subject: [PATCH] [SPARK-10176] [SQL] Show partially analyzed plans when checkAnswer fails to analyze This PR takes over https://github.com/apache/spark/pull/8389. This PR improves `checkAnswer` to print the partially analyzed plan in addition to the user friendly error message, in order to aid debugging failing tests. In doing so, I ran into a conflict with the various ways that we bring a SQLContext into the tests. Depending on the trait we refer to the current context as `sqlContext`, `_sqlContext`, `ctx` or `hiveContext` with access modifiers `public`, `protected` and `private` depending on the defining class. I propose we refactor as follows: 1. All tests should only refer to a `protected sqlContext` when testing general features, and `protected hiveContext` when it is a method that only exists on a `HiveContext`. 2. All tests should only import `testImplicits._` (i.e., don't import `TestHive.implicits._`) Author: Wenchen Fan Closes #8584 from cloud-fan/cleanupTests. --- .../spark/sql/catalyst/plans/PlanTest.scala | 1 - .../apache/spark/sql/CachedTableSuite.scala | 156 ++++++------- .../spark/sql/ColumnExpressionSuite.scala | 16 +- .../spark/sql/DataFrameAggregateSuite.scala | 4 +- .../spark/sql/DataFrameComplexTypeSuite.scala | 6 +- .../spark/sql/DataFrameImplicitsSuite.scala | 8 +- .../apache/spark/sql/DataFrameStatSuite.scala | 10 +- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +- .../spark/sql/DataFrameTungstenSuite.scala | 6 +- .../spark/sql/ExtraStrategiesSuite.scala | 2 +- .../org/apache/spark/sql/JoinSuite.scala | 12 +- .../apache/spark/sql/ListTablesSuite.scala | 20 +- .../org/apache/spark/sql/QueryTest.scala | 27 ++- .../scala/org/apache/spark/sql/RowSuite.scala | 2 +- .../org/apache/spark/sql/SQLConfSuite.scala | 44 ++-- .../apache/spark/sql/SQLContextSuite.scala | 12 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 40 ++-- .../apache/spark/sql/SerializationSuite.scala | 2 +- .../spark/sql/StringFunctionsSuite.scala | 47 ++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 42 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 6 +- .../columnar/InMemoryColumnarQuerySuite.scala | 41 ++-- .../columnar/PartitionBatchPruningSuite.scala | 20 +- .../spark/sql/execution/ExchangeSuite.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 99 ++++---- .../execution/RowFormatConvertersSuite.scala | 16 +- .../spark/sql/execution/SortSuite.scala | 1 + .../spark/sql/execution/SparkPlanTest.scala | 27 +-- .../sql/execution/TungstenSortSuite.scala | 12 +- .../TungstenAggregationIteratorSuite.scala | 2 +- .../datasources/json/JsonSuite.scala | 214 +++++++++--------- .../datasources/json/TestJsonData.scala | 34 +-- .../parquet/ParquetCompatibilityTest.scala | 5 +- .../datasources/parquet/ParquetIOSuite.scala | 52 ++--- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 42 ++-- .../datasources/parquet/ParquetTest.scala | 9 +- .../execution/joins/BroadcastJoinSuite.scala | 10 +- .../execution/joins/HashedRelationSuite.scala | 6 +- .../sql/execution/joins/InnerJoinSuite.scala | 9 +- .../sql/execution/joins/OuterJoinSuite.scala | 8 +- .../sql/execution/joins/SemiJoinSuite.scala | 8 +- .../execution/metric/SQLMetricsSuite.scala | 24 +- .../sql/execution/ui/SQLListenerSuite.scala | 8 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 20 +- .../spark/sql/jdbc/JDBCWriteSuite.scala | 52 ++--- .../sources/CreateTableAsSelectSuite.scala | 2 - .../spark/sql/sources/DataSourceTest.scala | 4 +- .../spark/sql/sources/InsertSuite.scala | 3 +- .../sql/sources/PartitionedWriteSuite.scala | 8 +- .../spark/sql/sources/SaveLoadSuite.scala | 1 - .../apache/spark/sql/test/SQLTestData.scala | 52 ++--- .../apache/spark/sql/test/SQLTestUtils.scala | 41 ++-- .../spark/sql/test/SharedSQLContext.scala | 17 +- .../spark/sql/test/TestSQLContext.scala | 2 +- .../apache/spark/sql/hive/test/TestHive.scala | 7 +- .../spark/sql/hive/CachedTableSuite.scala | 20 +- .../spark/sql/hive/ErrorPositionSuite.scala | 8 +- .../hive/HiveDataFrameAnalyticsSuite.scala | 13 +- .../sql/hive/HiveDataFrameJoinSuite.scala | 6 +- .../sql/hive/HiveDataFrameWindowSuite.scala | 7 +- .../sql/hive/HiveMetastoreCatalogSuite.scala | 20 +- .../spark/sql/hive/HiveParquetSuite.scala | 12 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 11 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 35 ++- .../spark/sql/hive/ListTablesSuite.scala | 12 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 13 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 20 +- .../hive/ParquetHiveCompatibilitySuite.scala | 24 +- .../spark/sql/hive/QueryPartitionSuite.scala | 18 +- .../spark/sql/hive/StatisticsSuite.scala | 42 ++-- .../org/apache/spark/sql/hive/UDFSuite.scala | 16 +- .../execution/AggregationQuerySuite.scala | 19 +- .../hive/execution/HiveComparisonTest.scala | 4 +- .../sql/hive/execution/HiveExplainSuite.scala | 11 +- .../HiveOperatorQueryableSuite.scala | 8 +- .../sql/hive/execution/HivePlanTest.scala | 8 +- .../sql/hive/execution/HiveUDFSuite.scala | 54 ++--- .../sql/hive/execution/SQLQuerySuite.scala | 39 ++-- .../execution/ScriptTransformationSuite.scala | 17 +- .../hive/orc/OrcHadoopFsRelationSuite.scala | 7 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 23 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 10 +- .../apache/spark/sql/hive/orc/OrcTest.scala | 9 +- .../apache/spark/sql/hive/parquetSuites.scala | 16 +- .../CommitFailureTestRelationSuite.scala | 9 +- .../sources/JsonHadoopFsRelationSuite.scala | 12 +- .../ParquetHadoopFsRelationSuite.scala | 15 +- .../SimpleTextHadoopFsRelationSuite.scala | 4 +- .../sql/sources/hadoopFsRelationSuites.scala | 28 ++- 90 files changed, 908 insertions(+), 999 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 765c1e2dda..f76a903dcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.util._ * Provides helper methods for comparing plans. */ class PlanTest extends SparkFunSuite { - /** * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index af7590c3d3..3a3541a842 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -34,7 +34,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { - val executedPlan = ctx.table(tableName).queryExecution.executedPlan + val executedPlan = sqlContext.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -44,7 +44,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { } def isMaterialized(rddId: Int): Boolean = { - ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("withColumn doesn't invalidate cached dataframe") { @@ -69,41 +69,41 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - ctx.cacheTable("tempTable") + sqlContext.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != ctx.cacheManager.lookupCachedData(testData)) + assert(None != sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == ctx.cacheManager.lookupCachedData(testData)) + assert(None == sqlContext.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - ctx.uncacheTable("tempTable") + sqlContext.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - ctx.cacheTable("tempTable1") + sqlContext.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - ctx.uncacheTable("tempTable2") + sqlContext.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -111,103 +111,103 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("too big for memory") { val data = "*" * 1000 - ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(ctx.table("bigData").count() === 200000L) - ctx.table("bigData").unpersist(blocking = true) + sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(sqlContext.table("bigData").count() === 200000L) + sqlContext.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - ctx.table("testData").cache() - assertCached(ctx.table("testData")) - ctx.table("testData").unpersist(blocking = true) + sqlContext.table("testData").cache() + assertCached(sqlContext.table("testData")) + sqlContext.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - ctx.table("testData").cache() - ctx.table("testData").count() - ctx.table("testData").unpersist(blocking = true) - assertCached(ctx.table("testData"), 0) + sqlContext.table("testData").cache() + sqlContext.table("testData").count() + sqlContext.table("testData").unpersist(blocking = true) + assertCached(sqlContext.table("testData"), 0) } test("isCached") { - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") - assertCached(ctx.table("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + assertCached(sqlContext.table("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - ctx.uncacheTable("testData") - assert(!ctx.isCached("testData")) - assert(ctx.table("testData").queryExecution.withCachedData match { + sqlContext.uncacheTable("testData") + assert(!sqlContext.isCached("testData")) + assert(sqlContext.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - ctx.cacheTable("testData") - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + assertCached(sqlContext.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("read from cached table and uncache") { - ctx.cacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData")) + sqlContext.cacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData")) - ctx.uncacheTable("testData") - checkAnswer(ctx.table("testData"), testData.collect().toSeq) - assertCached(ctx.table("testData"), 0) + sqlContext.uncacheTable("testData") + checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) + assertCached(sqlContext.table("testData"), 0) } test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - ctx.cacheTable("selectStar") + sqlContext.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - ctx.uncacheTable("selectStar") + sqlContext.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - ctx.cacheTable("testData") + sqlContext.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -215,7 +215,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") + assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -224,14 +224,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -239,14 +239,14 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(ctx.table("testCacheTable")) + assertCached(sqlContext.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - ctx.uncacheTable("testCacheTable") + sqlContext.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -254,7 +254,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(ctx.table("testData")) + assertCached(sqlContext.table("testData")) val rddId = rddIdOf("testData") assert( @@ -266,7 +266,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - ctx.uncacheTable("testData") + sqlContext.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -274,7 +274,7 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - ctx.table("testData").queryExecution.withCachedData.collect { + sqlContext.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -283,46 +283,48 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - ctx.table("t1") - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + sqlContext.table("t1") + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - ctx.cacheTable("t1") + sqlContext.cacheTable("t1") - assert(ctx.isCached("t1")) - assert(ctx.isCached("t2")) + assert(sqlContext.isCached("t1")) + assert(sqlContext.isCached("t2")) - ctx.dropTempTable("t1") - assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) - assert(!ctx.isCached("t2")) + sqlContext.dropTempTable("t1") + assert( + intercept[RuntimeException](sqlContext.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!sqlContext.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") - ctx.clearCache() - assert(ctx.cacheManager.isEmpty) + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + sqlContext.clearCache() + assert(sqlContext.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("Clear CACHE") - assert(ctx.cacheManager.isEmpty) + assert(sqlContext.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - ctx.cacheTable("t1") - ctx.cacheTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() @@ -331,8 +333,8 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { Accumulators.synchronized { val accsSize = Accumulators.originals.size - ctx.uncacheTable("t1") - ctx.uncacheTable("t2") + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 37738ec5b3..4e988f074b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -29,7 +29,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - ctx.createDataFrame(ctx.sparkContext.parallelize( + sqlContext.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -286,7 +286,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -307,7 +307,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val testData = sqlContext.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -350,7 +350,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("!==") { - val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( + val nullData = sqlContext.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -411,7 +411,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("between") { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -556,7 +556,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -567,7 +567,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("sparkPartitionId") { // Make sure we have 2 partitions, each with 2 records. - val df = ctx.sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => + val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -578,7 +578,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { test("InputFileName") { withTempPath { dir => - val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) .head.getString(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 72cf7aab0b..c0950b09b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -66,12 +66,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 3c359dd840..09f7b50767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -28,19 +28,19 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { test("UDF on struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(struct($"a").as("s")).select(f($"s.a")).collect() } test("UDF on named_struct") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect() } test("UDF on array") { val f = udf((a: String) => a) - val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index e5d7d63441..094efbaead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -24,7 +24,7 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD of tuples") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -36,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext { test("RDD[Int]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), + sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), + sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 28bdd6f83b..6524abcf5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -29,7 +29,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample with replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = true, 0.05, seed = 13), Seq(5, 10, 52, 73).map(Row(_)) @@ -38,7 +38,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("sample without replacement") { val n = 100 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), Seq(16, 23, 88, 100).map(Row(_)) @@ -47,7 +47,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -164,7 +164,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("Frequent Items 2") { - val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4) + val rows = sparkContext.parallelize(Seq.empty[Int], 4) // this is a regression test, where when merging partitions, we omitted values with higher // counts than those that existed in the map when the map was full. This test should also fail // if anything like SPARK-9614 is observed once again @@ -182,7 +182,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("sampleBy") { - val df = ctx.range(0, 100).select((col("id") % 3).as("key")) + val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index a4871e247c..b5b9f11785 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -345,7 +345,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("replace column using withColumn") { - val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -506,7 +506,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -596,7 +596,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() @@ -619,14 +619,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -646,7 +646,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7324 dropDuplicates") { - val testData = sqlContext.sparkContext.parallelize( + val testData = sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -869,7 +869,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val df = sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 77907e9136..7ae12a7895 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -32,7 +32,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test simple types") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { - val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) } } @@ -40,7 +40,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { test("test struct type") { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val struct = Row(1, 2L, 3.0F, 3.0) - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + val data = sparkContext.parallelize(Seq(Row(1, struct))) val schema = new StructType() .add("a", IntegerType) @@ -60,7 +60,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { val innerStruct = Row(1, "abcd") val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") - val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + val data = sparkContext.parallelize(Seq(Row(1, outerStruct))) val schema = new StructType() .add("a", IntegerType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 8d2f45d703..78a98798ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -52,7 +52,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { try { sqlContext.experimental.extraStrategies = TestStrategy :: Nil - val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( df.select("a"), Row("so fast")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index f5c5046a8e..b05435bad5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -31,7 +31,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -59,7 +59,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") for (sortMergeJoinEnabled <- Seq(true, false)) { withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") { @@ -138,7 +138,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") { Seq( @@ -167,7 +167,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = ctx.planner.EquiJoinSelection(join) + val planned = sqlContext.planner.EquiJoinSelection(join) assert(planned.size === 1) } @@ -442,7 +442,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted left semi join operator selection") { - ctx.cacheManager.clearCache() + sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index babf8835d2..eab0fbb196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -32,33 +32,33 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex } after { - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + sqlContext.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -66,7 +66,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(ctx.tables(), sql("SHOW TABLes")).foreach { + Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -77,9 +77,9 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "ListTablesSuiteTable") ) checkAnswer( - ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - ctx.dropTempTable("tables") + sqlContext.dropTempTable("tables") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3649c2a97b..cada03e9ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation -class QueryTest extends PlanTest { +abstract class QueryTest extends PlanTest { + + protected def sqlContext: SQLContext // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -56,18 +58,33 @@ class QueryTest extends PlanTest { * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ - protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { - QueryTest.checkAnswer(df, expectedAnswer) match { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + val analyzedDF = try df catch { + case ae: AnalysisException => + val currentValue = sqlContext.conf.dataFrameEagerAnalysis + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) + val partiallyAnalzyedPlan = df.queryExecution.analyzed + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, currentValue) + fail( + s""" + |Failed to analyze query: $ae + |$partiallyAnalzyedPlan + | + |${stackTraceToString(ae)} + |""".stripMargin) + } + + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => } } - protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } - protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 77ccd6f775..3ba14d7602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -57,7 +57,7 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { test("serialize w/ kryo") { val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) + val serializer = new SparkSqlSerializer(sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 7699adadd9..c35b31c96d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -27,58 +27,58 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(ctx.sparkContext) + val newContext = new SQLContext(sparkContext) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - ctx.conf.clear() - assert(ctx.getAllConfs.size === 0) + sqlContext.conf.clear() + assert(sqlContext.getAllConfs.size === 0) - ctx.setConf(testKey, testVal) - assert(ctx.getConf(testKey) === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + sqlContext.setConf(testKey, testVal) + assert(sqlContext.getConf(testKey) === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(ctx.getConf(testKey) == testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getAllConfs.contains(testKey)) + assert(sqlContext.getConf(testKey) == testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getAllConfs.contains(testKey)) - ctx.conf.clear() + sqlContext.conf.clear() } test("parse SQL set commands") { - ctx.conf.clear() + sqlContext.conf.clear() sql(s"set $testKey=$testVal") - assert(ctx.getConf(testKey, testVal + "_") === testVal) - assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(sqlContext.getConf(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(ctx.getConf("some.property", "0") === "20") + assert(sqlContext.getConf("some.property", "0") === "20") sql("set some.property = 40") - assert(ctx.getConf("some.property", "0") === "40") + assert(sqlContext.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(ctx.getConf(key, "0") === vs) + assert(sqlContext.getConf(key, "0") === vs) sql(s"set $key=") - assert(ctx.getConf(key, "0") === "") + assert(sqlContext.getConf(key, "0") === "") - ctx.conf.clear() + sqlContext.conf.clear() } test("deprecated property") { - ctx.conf.clear() + sqlContext.conf.clear() sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.conf.numShufflePartitions === 10) + assert(sqlContext.conf.numShufflePartitions === 10) } test("invalid conf value") { - ctx.conf.clear() + sqlContext.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 007be12950..dd88ae3700 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -24,7 +24,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext { override def afterAll(): Unit = { try { - SQLContext.setLastInstantiatedContext(ctx) + SQLContext.setLastInstantiatedContext(sqlContext) } finally { super.afterAll() } @@ -32,18 +32,18 @@ class SQLContextSuite extends SparkFunSuite with SharedSQLContext { test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) + val sqlContext = SQLContext.getOrCreate(sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(ctx.sparkContext) - assert(SQLContext.getOrCreate(ctx.sparkContext) != null, + val sqlContext = new SQLContext(sparkContext) + assert(SQLContext.getOrCreate(sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0ef25fe0fa..05f2000459 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -147,14 +147,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(sqlContext.sparkContext) + val newContext = new SQLContext(sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -196,7 +196,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sqlContext.sparkContext.parallelize( + sqlContext.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") @@ -215,7 +215,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-6201 IN type conversion") { sqlContext.read.json( - sqlContext.sparkContext.parallelize( + sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") @@ -1342,7 +1342,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3483 Special chars in column names") { - val data = sqlContext.sparkContext.parallelize( + val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") @@ -1385,13 +1385,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) sqlContext.dropTempTable("data") sqlContext.read.json( - sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) sqlContext.dropTempTable("data") } @@ -1412,10 +1412,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1424,7 +1424,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1432,14 +1432,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-4699 case sensitivity SQL query") { sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") @@ -1452,14 +1452,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1543,7 +1543,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-7067: order by queries for complex ExtractValue chain") { withTempTable("t") { - sqlContext.read.json(sqlContext.sparkContext.makeRDD( + sqlContext.read.json(sparkContext.makeRDD( """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } @@ -1610,8 +1610,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen updates peak execution memory") { withSQLConf((SQLConf.CODEGEN_ENABLED.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "aggregation with codegen") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "aggregation with codegen") { testCodeGen( "SELECT key, count(value) FROM testData GROUP BY key", (1 to 100).map(i => Row(i, 1))) @@ -1670,8 +1669,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - val sc = sqlContext.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sort") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { sortTest() } } @@ -1679,7 +1677,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-9511: error with table starting with number") { withTempTable("1one") { - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") .registerTempTable("1one") checkAnswer(sql("select count(num) from 1one"), Row(10)) @@ -1690,7 +1688,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withTempPath { dir => val path = dir.getCanonicalPath val df = - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df .write .format("parquet") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 45d0ee4a8e..ddab918629 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sqlContext.sparkContext) + val _sqlContext = new SQLContext(sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index b91438baea..e12e6bea30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -268,9 +268,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row(3, 4)) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("length(c)"), // int type of the argument is unacceptable - Row("5.0000")) + df.selectExpr("length(c)") // int type of the argument is unacceptable } } @@ -284,63 +282,46 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val tuple = - ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short], - 3.13223f, 4, 5L, 6.48173d, Decimal(7.128381)) - val df = - Seq(tuple) - .toDF( - "a", // string "aa" - "b", // byte 1 - "c", // short 2 - "d", // float 3.13223f - "e", // integer 4 - "f", // long 5L - "g", // double 6.48173d - "h") // decimal 7.128381 - - checkAnswer( - df.select(format_number($"f", 4)), + val df = sqlContext.range(1) + + checkAnswer( + df.select(format_number(lit(5L), 4)), Row("5.0000")) checkAnswer( - df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer + df.select(format_number(lit(1.toByte), 4)), // convert the 1st argument to integer Row("1.0000")) checkAnswer( - df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer + df.select(format_number(lit(2.toShort), 4)), // convert the 1st argument to integer Row("2.0000")) checkAnswer( - df.selectExpr("format_number(d, e)"), // convert the 1st argument to double + df.select(format_number(lit(3.1322.toFloat), 4)), // convert the 1st argument to double Row("3.1322")) checkAnswer( - df.selectExpr("format_number(e, e)"), // not convert anything + df.select(format_number(lit(4), 4)), // not convert anything Row("4.0000")) checkAnswer( - df.selectExpr("format_number(f, e)"), // not convert anything + df.select(format_number(lit(5L), 4)), // not convert anything Row("5.0000")) checkAnswer( - df.selectExpr("format_number(g, e)"), // not convert anything + df.select(format_number(lit(6.48173), 4)), // not convert anything Row("6.4817")) checkAnswer( - df.selectExpr("format_number(h, e)"), // not convert anything + df.select(format_number(lit(BigDecimal(7.128381)), 4)), // not convert anything Row("7.1284")) intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable - Row("5.0000")) + df.select(format_number(lit("aa"), 4)) // string type of the 1st argument is unacceptable } intercept[AnalysisException] { - checkAnswer( - df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable - Row("5.0000")) + df.selectExpr("format_number(4, 6.48173)") // non-integral type 2nd argument is unacceptable } // for testing the mutable state of the expression in code gen. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index eb275af101..e0435a0dba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -55,23 +55,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - ctx.dropTempTable("tmp_table") + sqlContext.dropTempTable("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => - val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - ctx.dropTempTable("test_table") + sqlContext.dropTempTable("test_table") } } test("error reporting for incorrect number of arguments") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = ctx.emptyDataFrame + val df = sqlContext.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -87,24 +87,24 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("Simple UDF") { - ctx.udf.register("strLenScala", (_: String).length) + sqlContext.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - ctx.udf.register("random0", () => { Math.random()}) + sqlContext.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) - val df = ctx.sparkContext.parallelize( + val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("integerData") @@ -114,7 +114,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -133,7 +133,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -150,10 +150,10 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) - ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) - ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) - ctx.udf.register("timesHundred", (n: Long) => { n * 100 }) + sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) + sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) + sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") @@ -172,7 +172,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -181,13 +181,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - ctx.udf.register("intExpected", (x: Int) => x) + sqlContext.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b6d279ae47..fa8f9c8e00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -90,7 +90,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { } test("UDTs and UDFs") { - ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -148,8 +148,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { StructField("vec", new MyDenseVectorUDT, false) )) - val stringRDD = ctx.sparkContext.parallelize(data) - val jsonRDD = ctx.read.schema(schema).json(stringRDD) + val stringRDD = sparkContext.parallelize(data) + val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 83db9b6510..cd3644eb9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -31,7 +31,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -39,16 +39,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics - ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) + sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) .toDF().registerTempTable("sizeTst") - ctx.cacheTable("sizeTst") + sqlContext.cacheTable("sizeTst") assert( - ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - ctx.conf.autoBroadcastJoinThreshold) + sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + sqlContext.conf.autoBroadcastJoinThreshold) } test("projection") { - val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -57,7 +57,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = ctx.executePlan(testData.logicalPlan).executedPlan + val plan = sqlContext.executePlan(testData.logicalPlan).executedPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -69,7 +69,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("repeatedData") + sqlContext.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -81,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("nullableRepeatedData") + sqlContext.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -96,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - ctx.cacheTable("timestamps") + sqlContext.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -108,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - ctx.cacheTable("withEmptyParts") + sqlContext.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -157,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Create a RDD for the schema val rdd = - ctx.sparkContext.parallelize((1 to 100), 10).map { i => + sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -177,24 +177,24 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan + sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - ctx.isCached("InMemoryCache_different_data_types"), + sqlContext.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - ctx.table("InMemoryCache_different_data_types").collect()) - ctx.dropTempTable("InMemoryCache_different_data_types") + sqlContext.table("InMemoryCache_different_data_types").collect()) + sqlContext.dropTempTable("InMemoryCache_different_data_types") } test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { - val df = - ctx.range(1, 100).selectExpr("id % 10 as id").rdd.map(id => Tuple1(s"str_$id")).toDF("i") + val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") val cached = df.cache() // count triggers the caching action. It should not throw. cached.count() @@ -205,7 +205,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { // Check result. checkAnswer( cached, - ctx.range(1, 100).selectExpr("id % 10 as id").rdd.map(id => Tuple1(s"str_$id")).toDF("i") + sqlContext.range(1, 100).selectExpr("id % 10 as id") + .rdd.map(id => Tuple1(s"str_$id")).toDF("i") ) // Drop the cache. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index ab2644eb45..6b7401464f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -25,32 +25,32 @@ import org.apache.spark.sql.test.SQLTestData._ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ - private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) - val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => + val pruningData = sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators - ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") - ctx.cacheTable("pruningData") + sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + sqlContext.cacheTable("pruningData") } override protected def afterAll(): Unit = { try { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) - ctx.uncacheTable("pruningData") + sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + sqlContext.uncacheTable("pruningData") } finally { super.afterAll() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 8998f51111..911d12e93e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + test("shuffling UnsafeRows in exchange") { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index fad93b014c..cafa1d5154 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution -import org.apache.spark.SparkFunSuite import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row, SQLConf} import org.apache.spark.sql.catalyst.InternalRow @@ -31,14 +30,14 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -class PlannerSuite extends SparkFunSuite with SharedSQLContext { +class PlannerSuite extends SharedSQLContext { import testImplicits._ setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = plannedOption.getOrElse( @@ -53,8 +52,8 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("unions are collapsed") { - val _ctx = ctx - import _ctx.planner._ + val planner = sqlContext.planner + import planner._ val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head val logicalUnions = query collect { case u: logical.Union => u } @@ -81,33 +80,30 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { - def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) - val fields = fieldTypes.zipWithIndex.map { - case (dataType, index) => StructField(s"c${index}", dataType, true) - } :+ StructField("key", IntegerType, true) - val schema = StructType(fields) - val row = Row.fromSeq(Seq.fill(fields.size)(null)) - val rowRDD = ctx.sparkContext.parallelize(row :: Nil) - ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit") - - val planned = sql( - """ - |SELECT l.a, l.b - |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) - """.stripMargin).queryExecution.executedPlan - - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - - ctx.dropTempTable("testLimit") + def checkPlan(fieldTypes: Seq[DataType]): Unit = { + withTempTable("testLimit") { + val fields = fieldTypes.zipWithIndex.map { + case (dataType, index) => StructField(s"c${index}", dataType, true) + } :+ StructField("key", IntegerType, true) + val schema = StructType(fields) + val row = Row.fromSeq(Seq.fill(fields.size)(null)) + val rowRDD = sparkContext.parallelize(row :: Nil) + sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + + val planned = sql( + """ + |SELECT l.a, l.b + |FROM testData2 l JOIN (SELECT * FROM testLimit LIMIT 1) r ON (l.a = r.key) + """.stripMargin).queryExecution.executedPlan + + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + } } - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - val simpleTypes = NullType :: BooleanType :: @@ -124,7 +120,9 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { StringType :: BinaryType :: Nil - checkPlan(simpleTypes, newThreshold = 16434) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "16434") { + checkPlan(simpleTypes) + } val complexTypes = ArrayType(DoubleType, true) :: @@ -136,36 +134,37 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { StructField("b", ArrayType(DoubleType), nullable = false), StructField("c", DoubleType, nullable = false))) :: Nil - checkPlan(complexTypes, newThreshold = 901617) - - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "901617") { + checkPlan(complexTypes) + } } test("InMemoryRelation statistics propagation") { - val origThreshold = ctx.conf.autoBroadcastJoinThreshold - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) - - testData.limit(3).registerTempTable("tiny") - sql("CACHE TABLE tiny") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { + withTempTable("tiny") { + testData.limit(3).registerTempTable("tiny") + sql("CACHE TABLE tiny") - val a = testData.as("a") - val b = ctx.table("tiny").as("b") - val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan + val a = testData.as("a") + val b = sqlContext.table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan - val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } - val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } + val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } + val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } - assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") - assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") + assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") + assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) + sqlContext.clearCache() + } + } } test("efficient limit -> project -> sort") { { val query = testData.select('key, 'value).sort('key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('key, 'value).logicalPlan.output) } @@ -175,7 +174,7 @@ class PlannerSuite extends SparkFunSuite with SharedSQLContext { // into it. val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan - val planned = ctx.planner.TakeOrderedAndProject(query) + val planned = sqlContext.planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) assert(planned.head.output === testData.select('value, 'key).logicalPlan.output) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index ef6ad59b71..4492e37ad0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { test("planner should insert unsafe->safe conversions when required") { val plan = Limit(10, outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) } test("filter can process unsafe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) } @@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { test("union requires all of its input rows' formats to agree") { val plan = Union(Seq(outputsSafe, outputsUnsafe)) assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("union can process safe rows") { val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(!preparedPlan.outputsUnsafeRows) } test("union can process unsafe rows") { val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = ctx.prepareForExecution.execute(plan) + val preparedPlan = sqlContext.prepareForExecution.execute(plan) assert(preparedPlan.outputsUnsafeRows) } test("round trip with ConvertToUnsafe and ConvertToSafe") { val input = Seq(("hello", 1), ("world", 2)) checkAnswer( - ctx.createDataFrame(input), + sqlContext.createDataFrame(input), plan => ConvertToSafe(ConvertToUnsafe(plan)), input.map(Row.fromTuple) ) } test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SparkPlan.currentContext.set(ctx) + SparkPlan.currentContext.set(sqlContext) val schema = ArrayType(StringType) val rows = (1 to 100).map { i => InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 8fa77b0fcb..3073d492e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext class SortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder // This test was originally added as an example of how to use [[SparkPlanTest]]; // it's not designed to be a comprehensive test of ExternalSort. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 5ab8f44fae..de45ae4635 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -31,14 +31,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def _sqlContext: SQLContext - - /** - * Creates a DataFrame from a local Seq of Product. - */ - implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - _sqlContext.implicits.localSeqToDataFrameHolder(data) - } + protected def sqlContext: SQLContext /** * Runs the plan and makes sure the answer matches the expected result. @@ -98,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -122,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -149,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, _sqlContext) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -170,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -210,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - _sqlContext: SQLContext): Option[String] = { + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, _sqlContext) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -238,10 +231,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = _sqlContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 3158458edb..7a0f0dfd2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -29,15 +29,16 @@ import org.apache.spark.sql.types._ * A test suite that generates randomized data to test the [[TungstenSort]] operator. */ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder override def beforeAll(): Unit = { super.beforeAll() - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { try { - ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + sqlContext.conf.unsetConf(SQLConf.CODEGEN_ENABLED) } finally { super.afterAll() } @@ -64,8 +65,7 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { } test("sorting updates peak execution memory") { - val sc = ctx.sparkContext - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child), @@ -83,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with SharedSQLContext { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = ctx.createDataFrame( - ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + val inputDf = sqlContext.createDataFrame( + sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) assert(TungstenSort.supportsSchema(inputDf.schema)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 5fdb82b067..afda0d29f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -37,7 +37,7 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => { () => new InterpretedMutableProjection(expr, schema) } - val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy") + val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 1174b27732..6a18cc6d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -215,7 +215,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = ctx.read.json(jsonNullStruct) + val jsonDF = sqlContext.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -234,7 +234,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = ctx.read.json(primitiveFieldAndType) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -262,7 +262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -361,7 +361,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = ctx.read.json(complexFieldAndType1) + val jsonDF = sqlContext.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -377,7 +377,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -449,7 +449,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -502,7 +502,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = ctx.read.json(complexFieldValueTypeConflict) + val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -526,7 +526,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = ctx.read.json(arrayElementTypeConflict) + val jsonDF = sqlContext.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -554,7 +554,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = ctx.read.json(missingFields) + val jsonDF = sqlContext.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -573,9 +573,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) + val jsonDF = sqlContext.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +590,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - ctx.read.schema(schema).json(path) + sqlContext.read.schema(schema).json(path) .queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.paths === Array(path)) @@ -603,7 +603,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = ctx.read.json(path) + val jsonDF = sqlContext.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -672,7 +672,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = ctx.read.schema(schema).json(path) + val jsonDF1 = sqlContext.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -689,7 +689,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -710,7 +710,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -738,7 +738,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -764,7 +764,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -782,7 +782,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF = ctx.read.json(complexFieldAndType2) + val jsonDF = sqlContext.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -805,7 +805,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = ctx.read.json(jsonArray) + val jsonDF = sqlContext.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -823,64 +823,63 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val jsonDF = ctx.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") - - val schema = StructType( - StructField("_unparsed", StringType, true) :: - StructField("a", StringType, true) :: - StructField("b", StringType, true) :: - StructField("c", StringType, true) :: Nil) - - assert(schema === jsonDF.schema) - - // In HiveContext, backticks should be used to access columns starting with a underscore. - checkAnswer( - sql( - """ - |SELECT a, b, c, _unparsed - |FROM jsonTable - """.stripMargin), - Row(null, null, null, "{") :: - Row(null, null, null, "") :: - Row(null, null, null, """{"a":1, b:2}""") :: - Row(null, null, null, """{"a":{, b:3}""") :: - Row("str_a_4", "str_b_4", "str_c_4", null) :: - Row(null, null, null, "]") :: Nil - ) - - checkAnswer( - sql( - """ - |SELECT a, b, c - |FROM jsonTable - |WHERE _unparsed IS NULL - """.stripMargin), - Row("str_a_4", "str_b_4", "str_c_4") - ) - - checkAnswer( - sql( - """ - |SELECT _unparsed - |FROM jsonTable - |WHERE _unparsed IS NOT NULL - """.stripMargin), - Row("{") :: - Row("") :: - Row("""{"a":1, b:2}""") :: - Row("""{"a":{, b:3}""") :: - Row("]") :: Nil - ) - - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempTable("jsonTable") { + val jsonDF = sqlContext.read.json(corruptRecords) + jsonDF.registerTempTable("jsonTable") + + val schema = StructType( + StructField("_unparsed", StringType, true) :: + StructField("a", StringType, true) :: + StructField("b", StringType, true) :: + StructField("c", StringType, true) :: Nil) + + assert(schema === jsonDF.schema) + + // In HiveContext, backticks should be used to access columns starting with a underscore. + checkAnswer( + sql( + """ + |SELECT a, b, c, _unparsed + |FROM jsonTable + """.stripMargin), + Row(null, null, null, "{") :: + Row(null, null, null, "") :: + Row(null, null, null, """{"a":1, b:2}""") :: + Row(null, null, null, """{"a":{, b:3}""") :: + Row("str_a_4", "str_b_4", "str_c_4", null) :: + Row(null, null, null, "]") :: Nil + ) + + checkAnswer( + sql( + """ + |SELECT a, b, c + |FROM jsonTable + |WHERE _unparsed IS NULL + """.stripMargin), + Row("str_a_4", "str_b_4", "str_c_4") + ) + + checkAnswer( + sql( + """ + |SELECT _unparsed + |FROM jsonTable + |WHERE _unparsed IS NOT NULL + """.stripMargin), + Row("{") :: + Row("") :: + Row("""{"a":1, b:2}""") :: + Row("""{"a":{, b:3}""") :: + Row("]") :: Nil + ) + } + } } test("SPARK-4068: nulls in arrays") { - val jsonDF = ctx.read.json(nullsInArrays) + val jsonDF = sqlContext.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -926,7 +925,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = ctx.createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -949,7 +948,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = ctx.createDataFrame(rowRDD2, schema2) + val df3 = sqlContext.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -957,8 +956,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = ctx.read.json(primitiveFieldAndType) - val primTable = ctx.read.json(jsonDF.toJSON) + val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val primTable = sqlContext.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -970,8 +969,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = ctx.read.json(complexFieldAndType1) - val compTable = ctx.read.json(complexJsonDF.toJSON) + val complexJsonDF = sqlContext.read.json(complexFieldAndType1) + val compTable = sqlContext.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1039,25 +1038,25 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Some(empty), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation0 = LogicalRelation(relation0) val relation1 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation1 = LogicalRelation(relation1) val relation2 = new JSONRelation( Some(singleRow), 0.5, Some(StructType(StructField("a", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation2 = LogicalRelation(relation2) val relation3 = new JSONRelation( Some(singleRow), 1.0, Some(StructType(StructField("b", IntegerType, true) :: Nil)), - None, None)(ctx) + None, None)(sqlContext) val logicalRelation3 = LogicalRelation(relation3) assert(relation0 !== relation1) @@ -1078,18 +1077,18 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath(dir => { val path = dir.getCanonicalFile.toURI.toString - ctx.sparkContext.parallelize(1 to 100) + sparkContext.parallelize(1 to 100) .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, options = Map("path" -> path)) val d2 = ResolvedDataSource( - ctx, + sqlContext, userSpecifiedSchema = None, partitionColumns = Array.empty[String], provider = classOf[DefaultSource].getCanonicalName, @@ -1105,24 +1104,21 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - - val schemaWithSimpleMap = StructType( - StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - try { - val temp = Utils.createTempDir().getPath - - val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) - df.write.mode("overwrite").parquet(temp) - // order of MapType is not defined - assert(ctx.read.parquet(temp).count() == 5) - - val df2 = ctx.read.json(corruptRecords) - df2.write.mode("overwrite").parquet(temp) - checkAnswer(ctx.read.parquet(temp), df2.collect()) - } finally { - ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { + withTempDir { dir => + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + + val path = dir.getAbsolutePath + df.write.mode("overwrite").parquet(path) + // order of MapType is not defined + assert(sqlContext.read.parquet(path).count() == 5) + + val df2 = sqlContext.read.json(corruptRecords) + df2.write.mode("overwrite").parquet(path) + checkAnswer(sqlContext.read.parquet(path), df2.collect()) + } } } @@ -1142,19 +1138,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val d1 = new File(root, "d1=1") // root/dt=1/col1=abc val p1_col1 = makePartition( - ctx.sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(2 to 5).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abc") // root/dt=1/col1=abd val p2 = makePartition( - ctx.sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), + sparkContext.parallelize(6 to 10).map(i => s"""{"a": 1, "b": "str$i"}"""), d1, "col1", "abd") - ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2864181cf9..713d1da1cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext private[json] trait TestJsonData { - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext def primitiveFieldAndType: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def emptyRecords: RDD[String] = - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -198,7 +198,7 @@ private[json] trait TestJsonData { """]""" :: Nil) - lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) - def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 91f3ce4d34..0835bd1230 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -39,12 +39,13 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { val fsPath = new Path(path) - val fs = fsPath.getFileSystem(configuration) + val fs = fsPath.getFileSystem(hadoopConfiguration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { override def accept(path: Path): Boolean = pathFilter(path) }).toSeq.asJava - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + val footers = + ParquetFileReader.readAllFootersInParallel(hadoopConfiguration, parquetFiles, true) footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 08d2b9dee9..cd552e8337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -101,7 +101,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) .toDF() @@ -119,7 +119,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("date type") { def makeDateRDD(): DataFrame = - sqlContext.sparkContext + sparkContext .parallelize(0 to 1000) .map(i => Tuple1(DateTimeUtils.toJavaDate(i))) .toDF() @@ -207,7 +207,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("compression codec") { def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala + .readMetaData(new Path(path), Some(hadoopConfiguration)).getBlocks.asScala .flatMap(_.getColumns.asScala) .map(_.getCodec.name()) .distinct @@ -277,14 +277,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("write metadata") { withTempPath { file => val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) + val fs = FileSystem.getLocal(hadoopConfiguration) val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + ParquetTypesConverter.writeMetaData(attributes, path, hadoopConfiguration) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val metaData = ParquetTypesConverter.readMetaData(path, Some(hadoopConfiguration)) val actualSchema = metaData.getFileMetaData.getSchema val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) @@ -355,7 +355,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( - sqlContext.sparkContext.hadoopConfiguration, + sparkContext.hadoopConfiguration, path, Collections.singletonList( new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) @@ -370,12 +370,12 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[DirectParquetOutputCommitter].getCanonicalName) sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -383,23 +383,23 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. // _temporary should be missing if direct output committer works. try { - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => @@ -407,25 +407,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(configuration) + val fs = path.getFileSystem(hadoopConfiguration) assert(!fs.exists(path)) } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[JobCommitFailureParquetOutputCommitter].getCanonicalName) @@ -436,8 +436,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(message === "Intentional exception for testing purposes") } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } @@ -455,11 +455,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("SPARK-7837 Do not close output writer twice when commitTask() fails") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) // Using a output committer that always fail when committing a task, so that both // `commitTask()` and `abortTask()` are invoked. - configuration.set( + hadoopConfiguration.set( "spark.sql.parquet.output.committer.class", classOf[TaskCommitFailureParquetOutputCommitter].getCanonicalName) @@ -483,8 +483,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index ed8bafb10c..7bac8609e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -517,7 +517,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index a379523d67..9edbb52268 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.Utils * A test suite that tests various Parquet queries. */ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext { + import testImplicits._ test("simple select queries") { withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { @@ -40,22 +41,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(ctx.table("t"), data.map(Row.fromTuple)) + checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) } - ctx.catalog.unregisterTable(Seq("tmp")) + sqlContext.catalog.unregisterTable(Seq("tmp")) } test("self-join") { @@ -118,9 +119,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema) + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = ctx.read.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -129,12 +130,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -153,9 +154,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -171,19 +172,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length + sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -193,7 +194,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) - val rowRDD = sqlContext.sparkContext.parallelize(Array(Row(Decimal("67123.45")))) + val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) val df = sqlContext.createDataFrame(rowRDD, schema) df.write.parquet(basePath) @@ -203,9 +204,6 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("SPARK-10005 Schema merging for nested struct") { - val sqlContext = _sqlContext - import sqlContext.implicits._ - withTempPath { dir => val path = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 5dbc7d1630..442fafb12f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ private[sql] trait ParquetTest extends SQLTestUtils { - protected def _sqlContext: SQLContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -43,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -55,7 +54,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(_sqlContext.read.parquet(path))) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -67,14 +66,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetDataFrame(data) { df => - _sqlContext.registerDataFrameAsTable(df, tableName) + sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 53a0e53fd7..dcbfdca71a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -33,8 +33,7 @@ import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} * without serializing the hashed relation, which does not happen in local mode. */ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - private var sc: SparkContext = null - private var sqlContext: SQLContext = null + protected var sqlContext: SQLContext = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -44,15 +43,14 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { val conf = new SparkConf() .setMaster("local-cluster[2,1,1024]") .setAppName("testing") - sc = new SparkContext(conf) + val sc = new SparkContext(conf) sqlContext = new SQLContext(sc) sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) } override def afterAll(): Unit = { - sc.stop() - sc = null + sqlContext.sparkContext.stop() sqlContext = null } @@ -60,7 +58,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { * Test whether the specified broadcast join updates the peak execution memory accumulator. */ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 4c9187a9a7..e5fd9e277f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val hashed = HashedRelation(data.iterator, numDataRows, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) - val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data") + val numDataRows = SQLMetrics.createLongMetric(sparkContext, "data") val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()).toArray diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index cc649b9bd4..4174ee0550 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder - private lazy val myUpperCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myUpperCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), Row(3, "C"), @@ -39,8 +40,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val myLowerCaseData = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), Row(3, "c"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index a1a617d7b7..c2e0bdac17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -40,8 +40,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches Row(2, -1.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index baa86e320d..3afd762942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), @@ -40,8 +40,8 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = ctx.createDataFrame( - ctx.sparkContext.parallelize(Seq( + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( Row(2, 3.0), Row(2, 3.0), Row(3, 2.0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 80006bf077..6afffae161 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -36,7 +36,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ test("LongSQLMetric should not box Long") { - val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long") + val l = SQLMetrics.createLongMetric(sparkContext, "long") val f = () => { l += 1L l.add(1L) @@ -50,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Normal accumulator should do boxing") { // We need this test to make sure BoxingFinder works. - val l = ctx.sparkContext.accumulator(0L) + val l = sparkContext.accumulator(0L) val f = () => { l += 1L } BoxingFinder.getClassReader(f.getClass).foreach { cl => val boxingFinder = new BoxingFinder() @@ -71,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet df.collect() - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node => expectedMetrics.contains(node.id) }.map { node => @@ -474,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = ctx.listener.executionIdToData.keySet + val previousExecutionIds = sqlContext.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) - ctx.sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds) + sparkContext.listenerBus.waitUntilEmpty(10000) + val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = ctx.listener.getExecution(executionId).get.jobs + val jobs = sqlContext.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = ctx.listener.getExecutionMetrics(executionId) + val metricValues = sqlContext.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq === Seq(2L)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 80d1e88956..2bbb41ca77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("basic") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( @@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(ctx) + val listener = new SQLListener(sqlContext) val executionId = 0 val df = createTestDataFrame listener.onExecutionStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d8c9a08d84..ed710689cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -255,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("Basic API") { - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(ctx.read.jdbc( + assert(sqlContext.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -330,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -340,8 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test DATE types in cache") { - val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -349,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext } test("test types for null value") { - val rows = ctx.read.jdbc( + val rows = sqlContext.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -396,7 +396,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 5dc3a2c07b..e23ee66931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,13 +22,12 @@ import java.util.Properties import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null @@ -76,8 +75,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon conn1.close() } - private lazy val sc = ctx.sparkContext - private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) private lazy val schema2 = StructType( @@ -91,49 +88,50 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -143,14 +141,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLCon test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 9bc3f6bcf6..6fc9febe49 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -26,10 +26,8 @@ import org.apache.spark.sql.execution.datasources.DDLException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils - class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index d74d29fb0b..af04079ec8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ - private[sql] abstract class DataSourceTest extends QueryTest { - protected def _sqlContext: SQLContext // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(_sqlContext.sparkContext) + val ctx = new SQLContext(sqlContext.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 084d83f6e9..5b70d258d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql.sources import java.io.File -import org.apache.spark.sql.{SaveMode, AnalysisException, Row} +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var path: File = null override def beforeAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 79b6e9b45c..c9791879ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = ctx.range(100).select($"id", lit(1).as("data")) + val df = sqlContext.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = ctx.range(100) + val base = sqlContext.range(100) val df = base.unionAll(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - ctx.read.load(path.getCanonicalPath), + sqlContext.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index f18546b4c2..10d2613689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { protected override lazy val sql = caseInsensitiveContext.sql _ - private lazy val sparkContext = caseInsensitiveContext.sparkContext private var originalDefaultSource: String = null private var path: File = null private var df: DataFrame = null diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 3fc02df954..520dea7f7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} * A collection of sample data used in SQL tests. */ private[sql] trait SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sqlContext: SQLContext // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext } import internalImplicits._ @@ -37,21 +37,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -63,7 +63,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -71,14 +71,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -90,7 +90,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -102,7 +102,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( BinaryData("12".getBytes, 1) :: BinaryData("22".getBytes, 5) :: BinaryData("122".getBytes, 3) :: @@ -113,7 +113,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -125,7 +125,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -135,7 +135,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) rdd.toDF().registerTempTable("arrayData") @@ -143,7 +143,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -154,13 +154,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = _sqlContext.sparkContext.parallelize( + val rdd = sqlContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -168,7 +168,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -178,7 +178,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -188,7 +188,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -197,13 +197,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - _sqlContext.sparkContext.parallelize( + sqlContext.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -212,13 +212,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -226,7 +226,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -234,7 +234,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = _sqlContext.sparkContext.parallelize( + val df = sqlContext.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -246,7 +246,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(sqlContext != null, "attempted to initialize test data before SQLContext.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index dc08306ad9..9214569f18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils @@ -47,13 +47,13 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def _sqlContext: SQLContext + protected def sparkContext = sqlContext.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = _sqlContext.sql _ + protected lazy val sql = sqlContext.sql _ /** * A helper object for importing SQL implicits. @@ -63,7 +63,14 @@ private[sql] trait SQLTestUtils * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self._sqlContext + protected override def _sqlContext: SQLContext = self.sqlContext + + // This must live here to preserve binary compatibility with Spark < 1.5. + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } } /** @@ -84,8 +91,8 @@ private[sql] trait SQLTestUtils /** * The Hadoop configuration used by the active [[SQLContext]]. */ - protected def configuration: Configuration = { - _sqlContext.sparkContext.hadoopConfiguration + protected def hadoopConfiguration: Configuration = { + sparkContext.hadoopConfiguration } /** @@ -96,12 +103,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(_sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => _sqlContext.conf.setConfString(key, value) - case (key, None) => _sqlContext.conf.unsetConf(key) + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) } } } @@ -133,7 +140,7 @@ private[sql] trait SQLTestUtils * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(_sqlContext.dropTempTable) + try f finally tableNames.foreach(sqlContext.dropTempTable) } /** @@ -142,7 +149,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - _sqlContext.sql(s"DROP TABLE IF EXISTS $name") + sqlContext.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -155,12 +162,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - _sqlContext.sql(s"CREATE DATABASE $dbName") + sqlContext.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -168,8 +175,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - _sqlContext.sql(s"USE $db") - try f finally _sqlContext.sql(s"USE default") + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") } /** @@ -177,7 +184,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - DataFrame(_sqlContext, plan) + DataFrame(sqlContext, plan) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index d23c6a0732..963d10eed6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.{ColumnName, SQLContext} +import org.apache.spark.sql.SQLContext /** @@ -36,9 +36,7 @@ trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected def ctx: TestSQLContext = _ctx - protected def sqlContext: TestSQLContext = _ctx - protected override def _sqlContext: SQLContext = _ctx + protected def sqlContext: SQLContext = _ctx /** * Initialize the [[TestSQLContext]]. @@ -64,15 +62,4 @@ trait SharedSQLContext extends SQLTestUtils { super.afterAll() } } - - /** - * Converts $"col name" into an [[Column]]. - * @since 1.3.0 - */ - // This must be duplicated here to preserve binary compatibility with Spark < 1.5. - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 92ef2f7d74..d99d191ebe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -47,6 +47,6 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } private object testData extends SQLTestData { - protected override def _sqlContext: SQLContext = self + protected override def sqlContext: SQLContext = self } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 57fea5d8db..77f43f9270 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand @@ -51,6 +51,11 @@ object TestHive // SPARK-8910 .set("spark.ui.enabled", "false"))) +trait TestHiveSingleton { + protected val sqlContext: SQLContext = TestHive + protected val hiveContext: TestHiveContext = TestHive +} + /** * A locally running test instance of Spark's Hive execution engine. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 39d315aaea..9adb3780a2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest { +class CachedTableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ def rddIdOf(tableName: String): Int = { val executedPlan = table(tableName).queryExecution.executedPlan @@ -95,18 +95,18 @@ class CachedTableSuite extends QueryTest { test("correct error on uncache of non-cached table") { intercept[IllegalArgumentException] { - TestHive.uncacheTable("src") + hiveContext.uncacheTable("src") } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { - TestHive.sql("CACHE TABLE src") + sql("CACHE TABLE src") assertCached(table("src")) - assert(TestHive.isCached("src"), "Table 'src' should be cached") + assert(hiveContext.isCached("src"), "Table 'src' should be cached") - TestHive.sql("UNCACHE TABLE src") + sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!TestHive.isCached("src"), "Table 'src' should not be cached") + assert(!hiveContext.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 30f5313d2b..cf73783693 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,12 +22,12 @@ import scala.util.Try import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.{AnalysisException, QueryTest} -class ErrorPositionSuite extends QueryTest with BeforeAndAfter { +class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ before { Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -122,7 +122,7 @@ class ErrorPositionSuite extends QueryTest with BeforeAndAfter { test(name) { val error = intercept[AnalysisException] { - quietly(sql(query)) + quietly(hiveContext.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index fb10f8583d..2e5cae415e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -19,24 +19,25 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. -class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { +class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext.implicits._ + import hiveContext.sql + private var testData: DataFrame = _ override def beforeAll() { testData = Seq((1, 2), (2, 4)).toDF("a", "b") - TestHive.registerDataFrameAsTable(testData, "mytable") + hiveContext.registerDataFrameAsTable(testData, "mytable") } override def afterAll(): Unit = { - TestHive.dropTempTable("mytable") + hiveContext.dropTempTable("mytable") } test("rollup") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 52e782768c..f621367eb5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton - -class HiveDataFrameJoinSuite extends QueryTest { +class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala index c177cbdd99..2c98f1c3cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.{Row, QueryTest} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HiveDataFrameWindowSuite extends QueryTest { +class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { + import hiveContext.implicits._ + import hiveContext.sql test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 574624d501..107457f79e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,18 +19,15 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.sources.DataSourceTest +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{Row, SaveMode, SQLContext} -import org.apache.spark.{Logging, SparkFunSuite} - -class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { +class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { + import hiveContext.implicits._ test("struct field should accept underscore in sub-column name") { val hiveTypeStr = "struct" @@ -46,14 +43,15 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { } test("duplicated metastore relations") { - val df = sql("SELECT * FROM src") + val df = hiveContext.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } } -class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive +class DataSourceWithHiveMetastoreCatalogSuite + extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ import testImplicits._ private val testDF = range(1, 3).select( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index fe0db5228d..5596ec6882 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) -class HiveParquetSuite extends QueryTest with ParquetTest { - private val ctx = TestHive - override def _sqlContext: SQLContext = ctx +class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton { test("Case insensitive attribute names") { withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { @@ -53,7 +51,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -66,7 +64,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - ctx.read.parquet(file.getCanonicalPath).registerTempTable("p") + hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index dc2d85f486..84f3db44ec 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLContext, QueryTest} import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.sql.types.DecimalType @@ -272,7 +272,11 @@ object SparkSQLConfTest extends Logging { } } -object SPARK_9757 extends QueryTest with Logging { +object SPARK_9757 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -282,10 +286,9 @@ object SPARK_9757 extends QueryTest with Logging { .set("spark.sql.hive.metastore.jars", "maven")) val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext import hiveContext.implicits._ - import org.apache.spark.sql.functions._ - val dir = Utils.createTempDir() dir.delete() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index d33e81227d..80a61f82fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -24,28 +24,25 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -/* Implicits */ -import org.apache.spark.sql.hive.test.TestHive._ - case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.hive.test.TestHive.implicits._ - +class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { + import hiveContext.implicits._ + import hiveContext.sql - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. - TestHive.reset() + hiveContext.reset() // Register the testData, which will be used in every test. testData.registerTempTable("testData") } @@ -96,9 +93,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -169,8 +166,8 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -185,9 +182,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -202,9 +199,9 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = TestHive.sparkContext.parallelize( + val rowRDD = hiveContext.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = TestHive.createDataFrame(rowRDD, schema) + val df = hiveContext.createDataFrame(rowRDD, schema) df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -217,11 +214,11 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { } test("SPARK-5498:partition schema does not match table schema") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") - val testDatawithNull = TestHive.sparkContext.parallelize( + val testDatawithNull = hiveContext.sparkContext.parallelize( (1 to 10).map(i => ThreeCloumntable(i, i.toString, null))).toDF() val tmpDir = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index d3388a9429..579631df77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,17 +19,15 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row -class ListTablesSuite extends QueryTest with BeforeAndAfterAll { +class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ - import org.apache.spark.sql.hive.test.TestHive.implicits._ - - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + val df = sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 20a50586d5..bf0db08490 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -22,15 +22,11 @@ import java.io.{IOException, File} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -39,10 +35,9 @@ import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll - with Logging { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ var jsonFilePath: String = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 997c667ec0..f16c257ab5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} -class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val _sqlContext: HiveContext = TestHive - private val sqlContext = _sqlContext - - private val df = sqlContext.range(10).coalesce(1) +class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + private lazy val df = sqlContext.range(10).coalesce(1) private def checkTablePath(dbName: String, tableName: String): Unit = { - // val hiveContext = sqlContext.asInstanceOf[HiveContext] - val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) - val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + val metastoreTable = hiveContext.catalog.client.getTable(dbName, tableName) + val expectedPath = hiveContext.catalog.client.getDatabase(dbName).location + "/" + tableName assert(metastoreTable.serdeProperties("path") === expectedPath) } @@ -220,7 +216,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") - sqlContext.refreshTable("t") + hiveContext.refreshTable("t") checkAnswer( sqlContext.table("t"), df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) @@ -252,7 +248,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") - sqlContext.refreshTable(s"$db.t") + hiveContext.refreshTable(s"$db.t") checkAnswer( sqlContext.table(s"$db.t"), df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 91d7a48208..49aab85cf1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -18,38 +18,20 @@ package org.apache.spark.sql.hive import java.sql.Timestamp -import java.util.{Locale, TimeZone} import org.apache.hadoop.hive.conf.HiveConf -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.{Row, SQLConf, SQLContext} - -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with BeforeAndAfterAll { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.hive.test.TestHiveSingleton +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { /** * Set the staging directory (and hence path to ignore Parquet files under) * to that set by [[HiveConf.ConfVars.STAGINGDIR]]. */ private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - private val originalTimeZone = TimeZone.getDefault - private val originalLocale = Locale.getDefault - - protected override def beforeAll(): Unit = { - TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) - Locale.setDefault(Locale.US) - } - - override protected def afterAll(): Unit = { - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - } - override protected def logParquetSchema(path: String): Unit = { val schema = readParquetSchema(path, { path => !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 1cc8a93e83..f542a5a025 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,22 +18,18 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils +import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ -class QueryPartitionSuite extends QueryTest with SQLTestUtils { - - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive - import ctx.implicits._ - - protected def _sqlContext = ctx - - test("SPARK-5068: query data when path doesn't exist"){ + test("SPARK-5068: query data when path doesn't exist") { withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { - val testData = ctx.sparkContext.parallelize( + val testData = sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() testData.registerTempTable("testData") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index e4fec7e2c8..6a692d6fce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,24 +17,15 @@ package org.apache.spark.sql.hive -import org.scalatest.BeforeAndAfterAll - import scala.reflect.ClassTag import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.hive.test.TestHiveSingleton -class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - private lazy val ctx: HiveContext = { - val ctx = org.apache.spark.sql.hive.test.TestHive - ctx.reset() - ctx.cacheTables = false - ctx - } - - import ctx.sql +class StatisticsSuite extends QueryTest with TestHiveSingleton { + import hiveContext.sql test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -54,9 +45,6 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { } } - // Ensure session state is initialized. - ctx.parseSql("use default") - assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", classOf[HiveNativeCommand]) @@ -80,7 +68,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes + hiveContext.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -114,7 +102,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") @@ -125,9 +113,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { // Try to analyze a temp table sql("""SELECT * FROM src""").registerTempTable("tempTable") intercept[UnsupportedOperationException] { - ctx.analyze("tempTable") + hiveContext.analyze("tempTable") } - ctx.catalog.unregisterTable(Seq("tempTable")) + hiveContext.catalog.unregisterTable(Seq("tempTable")) } test("estimates the size of a test MetastoreRelation") { @@ -155,8 +143,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -167,8 +155,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, expectedAnswer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) @@ -211,8 +199,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold - && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold + && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -225,8 +213,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { checkAnswer(df, answer) // check correctness of output - ctx.conf.settings.synchronized { - val tmp = ctx.conf.autoBroadcastJoinThreshold + hiveContext.conf.settings.synchronized { + val tmp = hiveContext.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 7ee1c8d13a..3ab4576811 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -18,18 +18,18 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.hive.test.TestHive +class UDFSuite extends QueryTest with TestHiveSingleton { test("UDF case insensitive") { - ctx.udf.register("random0", () => { Math.random() }) - ctx.udf.register("RANDOM1", () => { Math.random() }) - ctx.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + hiveContext.udf.register("random0", () => { Math.random() }) + hiveContext.udf.register("RANDOM1", () => { Math.random() }) + hiveContext.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4886a85948..b126ec455f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.hive.execution -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.hive.test.TestHiveSingleton -abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext - import sqlContext.implicits._ +abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ var originalUseAggregate2: Boolean = _ @@ -69,7 +65,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be data2.write.saveAsTable("agg2") val emptyDF = sqlContext.createDataFrame( - sqlContext.sparkContext.emptyRDD[Row], + sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) emptyDF.registerTempTable("emptyTable") @@ -597,7 +593,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") } - override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { (0 to 2).foreach { fallbackStartsAt => sqlContext.setConf( "spark.sql.TungstenAggregate.testFallbackStartsAt", @@ -605,6 +601,7 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. + // todo: remove it? val newActual = DataFrame(sqlContext, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { @@ -626,12 +623,12 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue } // Override it to make sure we call the actually overridden checkAnswer. - override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { checkAnswer(df, Seq(expectedAnswer)) } // Override it to make sure we call the actually overridden checkAnswer. - override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = { checkAnswer(df, expectedAnswer.collect()) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4d45249d9c..aa95ba94fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen { /** * When set, any cache files that result in test failures will be deleted. Used when the test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 11d7a872df..94162da4ea 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -17,17 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{SQLContext, QueryTest} -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton /** * A set of tests that validates support for Hive Explain command. */ -class HiveExplainSuite extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("explain extended command") { checkExistence(sql(" explain select * from src where key=123 "), true, @@ -83,7 +80,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils { test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { withTempTable("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - read.json(rdd).registerTempTable("jt") + hiveContext.read.json(rdd).registerTempTable("jt") val outputs = sql( s""" |EXPLAIN EXTENDED diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index efbef68cd4..0d4c7f86b3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** * A set of tests that validates commands can also be queried by like a table */ -class HiveOperatorQueryableSuite extends QueryTest { +class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { + import hiveContext._ + test("SPARK-5324 query result of describe command") { - loadTestTable("src") + hiveContext.loadTestTable("src") // register a describe command to be a temp table sql("desc src").registerTempTable("mydesc") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index ba56a8a6b6..cd055f9eca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -21,11 +21,11 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.expressions.Window -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton -class HivePlanTest extends QueryTest { - import TestHive._ - import TestHive.implicits._ +class HivePlanTest extends QueryTest with TestHiveSingleton { + import hiveContext.sql + import hiveContext.implicits._ test("udf constant folding") { Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 9c10ffe111..d9ba895e1e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.io.Writable import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.util.Utils @@ -43,10 +43,10 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUDFSuite extends QueryTest { +class HiveUDFSuite extends QueryTest with TestHiveSingleton { - import TestHive.{udf, sql} - import TestHive.implicits._ + import hiveContext.{udf, sql} + import hiveContext.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -123,12 +123,12 @@ class HiveUDFSuite extends QueryTest { | "value", value)).value FROM src """.stripMargin), Seq(Row("val_0"))) } - val codegenDefault = TestHive.getConf(SQLConf.CODEGEN_ENABLED) - TestHive.setConf(SQLConf.CODEGEN_ENABLED, true) + val codegenDefault = hiveContext.getConf(SQLConf.CODEGEN_ENABLED) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, true) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, false) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, false) testOrderInStruct() - TestHive.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) + hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault) } test("SPARK-6409 UDAFAverage test") { @@ -137,7 +137,7 @@ class HiveUDFSuite extends QueryTest { sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), Seq(Row(1.0, 260.182))) sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() + hiveContext.reset() } test("SPARK-2693 udaf aggregates test") { @@ -157,7 +157,7 @@ class HiveUDFSuite extends QueryTest { } test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() testData.registerTempTable("integerTable") @@ -168,11 +168,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("1"), Row("2"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - TestHive.reset() + hiveContext.reset() } test("UDFToListString") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") @@ -183,11 +183,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListString") - TestHive.reset() + hiveContext.reset() } test("UDFToListInt") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") @@ -198,11 +198,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch a component type in List<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToListInt") - TestHive.reset() + hiveContext.reset() } test("UDFToStringIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + @@ -214,11 +214,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToStringIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFToIntIntMap") { - val testData = TestHive.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() testData.registerTempTable("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + @@ -230,11 +230,11 @@ class HiveUDFSuite extends QueryTest { "JVM type erasure makes spark fail to catch key and value types in Map<>;") sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFToIntIntMap") - TestHive.reset() + hiveContext.reset() } test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() @@ -246,11 +246,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row(0), Row(2), Row(13))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - TestHive.reset() + hiveContext.reset() } test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() testData.registerTempTable("listStringTable") @@ -261,11 +261,11 @@ class HiveUDFSuite extends QueryTest { Seq(Row("a,b,c"), Row("d,e"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - TestHive.reset() + hiveContext.reset() } test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") @@ -280,11 +280,11 @@ class HiveUDFSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") - TestHive.reset() + hiveContext.reset() } test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( + val testData = hiveContext.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: @@ -297,7 +297,7 @@ class HiveUDFSuite extends QueryTest { Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - TestHive.reset() + hiveContext.reset() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1ff1d9a293..8126d02335 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,9 +26,7 @@ import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils @@ -65,12 +63,12 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext._ + import hiveContext.implicits._ test("UDTF") { - sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") // The function source code can be found at: // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF sql( @@ -509,19 +507,19 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) - checkAnswer(sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested"), - Seq.empty[Row]) + + sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested") checkAnswer( sql("SELECT * FROM test_ctas_1234"), sql("SELECT * FROM nested").collect().toSeq) intercept[AnalysisException] { - sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect() + sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect() } } test("test CTAS") { - checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src") checkAnswer( sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) @@ -614,7 +612,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val rowRdd = sparkContext.parallelize(row :: Nil) - TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable") + hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -1044,10 +1042,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { val thread = new Thread { override def run() { // To make sure this test works, this jar should not be loaded in another place. - TestHive.sql( - s"ADD JAR ${TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") + sql( + s"ADD JAR ${hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()}") try { - TestHive.sql( + sql( """ |CREATE TEMPORARY FUNCTION example_max |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' @@ -1097,21 +1095,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = - TestHive.createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) + createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) df.toDF("id", "datef").registerTempTable("test_SPARK8588") checkAnswer( - TestHive.sql( + sql( """ |select id, concat(year(datef)) |from test_SPARK8588 where concat(year(datef), ' year') in ('2015 year', '2014 year') """.stripMargin), Row(1, "2014") :: Row(2, "2015") :: Nil ) - TestHive.dropTempTable("test_SPARK8588") + dropTempTable("test_SPARK8588") } test("SPARK-9371: fix the support for special chars in column names for hive context") { - TestHive.read.json(TestHive.sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") @@ -1142,8 +1140,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("specifying database name for a temporary table is not allowed") { withTempPath { dir => val path = dir.getCanonicalPath - val df = - sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + val df = sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") df .write .format("parquet") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 9aca40f15a..cb8d0fca8e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -22,17 +22,14 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest { - - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { + import hiveContext.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -59,7 +56,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -73,7 +70,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = child, ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } @@ -88,7 +85,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = noSerdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) @@ -105,7 +102,7 @@ class ScriptTransformationSuite extends SparkPlanTest { output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), ioschema = serdeIOSchema - )(TestHive), + )(hiveContext), rowsDf.collect()) } assert(e.getMessage().contains("intentional exception")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index deec0048d2..9a299c3f9d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.sources.HadoopFsRelationTest import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + import testImplicits._ - import sqlContext._ - import sqlContext.implicits._ + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -48,7 +47,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.options(Map( + hiveContext.read.options(Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index a46ca9a2c9..52e09f9496 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.hive.orc import java.io.File -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.util.Utils -import org.scalatest.BeforeAndAfterAll import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.scalatest.BeforeAndAfterAll +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.util.Utils + // The data where the partitioning key exists only in the directory structure. case class OrcParData(intField: Int, stringField: String) @@ -38,7 +36,10 @@ case class OrcParData(intField: Int, stringField: String) case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot -class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { +class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + import hiveContext.implicits._ + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal def withTempDir(f: File => Unit): Unit = { @@ -58,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally TestHive.dropTempTable(tableName) + try f finally hiveContext.dropTempTable(tableName) } protected def makePartitionDir( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 80c38084f2..7a34cf731b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -21,12 +21,14 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton case class OrcData(intField: Int, stringField: String) -abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { +abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { + import hiveContext._ + var orcTableDir: File = null var orcTableAsDir: File = null @@ -156,7 +158,7 @@ class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( @@ -164,7 +166,7 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) - sql( + hiveContext.sql( s"""CREATE TEMPORARY TABLE normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index f7ba20ff41..88a0ed5117 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,15 +22,12 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton -private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => - protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive - protected val sqlContext = _sqlContext - import sqlContext.implicits._ - import sqlContext.sparkContext +private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { + import testImplicits._ /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 34d3434569..6842ec2b5e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -19,15 +19,11 @@ package org.apache.spark.sql.hive import java.io.File -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -58,6 +54,8 @@ case class ParquetDataWithKeyAndComplexTypes( * built in parquet support. */ class ParquetMetastoreSuite extends ParquetPartitioningTest { + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -536,6 +534,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { * A suite of tests for the Parquet support through the data sources API. */ class ParquetSourceSuite extends ParquetPartitioningTest { + import testImplicits._ + import hiveContext._ + override def beforeAll(): Unit = { super.beforeAll() dropTables("partitioned_parquet", @@ -684,9 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest { /** * A collection of tests for parquet data with various forms of partitioning. */ -abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext +abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ var partitionedTableDir: File = null var normalTableDir: File = null diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index b4640b1616..dc0531a6d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,16 +18,13 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - private val sqlContext = _sqlContext +class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index 8ca3a17085..1945b15002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types._ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = "json" - import sqlContext._ - test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -47,7 +45,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -65,14 +63,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { val data = Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil - val df = createDataFrame(sparkContext.parallelize(data), schema) + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. checkAnswer( - read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } @@ -90,14 +88,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { Row(new BigDecimal("10.02")) :: Row(new BigDecimal("20000.99")) :: Row(new BigDecimal("10000")) :: Nil - val df = createDataFrame(sparkContext.parallelize(data), schema) + val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. checkAnswer( - read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 06dadbb5fe..08c3c17973 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -28,10 +28,9 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = "parquet" + import testImplicits._ - import sqlContext._ - import sqlContext.implicits._ + override val dataSourceName: String = "parquet" test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => @@ -51,7 +50,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -69,7 +68,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -97,7 +96,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) + checkAnswer(hiveContext.read.format("parquet").load(path), df) } } @@ -107,7 +106,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. - range(1, 10) + hiveContext.range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") @@ -125,7 +124,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val summaryPath = new Path(path, "_metadata") val commonSummaryPath = new Path(path, "_common_metadata") - val fs = summaryPath.getFileSystem(configuration) + val fs = summaryPath.getFileSystem(hadoopConfiguration) fs.delete(summaryPath, true) fs.delete(commonSummaryPath, true) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index e8975e5f5c..1125ca6701 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - import sqlContext._ - test("save()/load() - partitioned table - simple queries - partition columns in data") { withTempDir { file => val basePath = new Path(file.getCanonicalPath) @@ -44,7 +42,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - read.format(dataSourceName) + hiveContext.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 7966b43596..2ad2618dfc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -28,14 +28,12 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { - override def _sqlContext: SQLContext = TestHive - protected val sqlContext = _sqlContext +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { import sqlContext.implicits._ val dataSourceName: String @@ -504,17 +502,17 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } test("SPARK-8578 specified custom output committer will not be used to append data") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) try { val df = sqlContext.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there data already exists, // this append should succeed because we will use the output committer associated @@ -533,12 +531,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } withTempPath { dir => - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) // Since Parquet has its own output committer setting, also set it // to AlwaysFailParquetOutputCommitter at here. - configuration.set("spark.sql.parquet.output.committer.class", + hadoopConfiguration.set("spark.sql.parquet.output.committer.class", classOf[AlwaysFailParquetOutputCommitter].getName) // Because there is no existing data, // this append will fail because AlwaysFailOutputCommitter is used when we do append @@ -549,8 +547,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) } } @@ -570,7 +568,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } test("SPARK-9899 Disable customized output committer when speculation is on") { - val clonedConf = new Configuration(configuration) + val clonedConf = new Configuration(hadoopConfiguration) val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) @@ -580,7 +578,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { sqlContext.sparkContext.conf.set("spark.speculation", "true") // Uses a customized output committer which always fails - configuration.set( + hadoopConfiguration.set( SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[AlwaysFailOutputCommitter].getName) @@ -597,8 +595,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } finally { // Hadoop 1 doesn't have `Configuration.unset` - configuration.clear() - clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + hadoopConfiguration.clear() + clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } }