diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 6197f10813a3b..eb8700369275e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1584,6 +1584,7 @@ class DataFrame private[sql]( def distinct(): DataFrame = dropDuplicates() /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ @@ -1593,12 +1594,17 @@ class DataFrame private[sql]( } /** + * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`). * @group basic * @since 1.3.0 */ def cache(): this.type = persist() /** + * Persist this [[DataFrame]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. * @group basic * @since 1.3.0 */ @@ -1608,6 +1614,8 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. * @group basic * @since 1.3.0 */ @@ -1617,6 +1625,7 @@ class DataFrame private[sql]( } /** + * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk. * @group basic * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c357f88a94dd0..d6bb1d2ad8e50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils /** @@ -565,7 +566,7 @@ class Dataset[T] private[sql]( * combined. * * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. * @since 1.6.0 */ def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) @@ -618,7 +619,6 @@ class Dataset[T] private[sql]( case _ => Alias(CreateStruct(rightOutput), "_2")() } - implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) withPlan[(T, U)](other) { (left, right) => @@ -697,11 +697,55 @@ class Dataset[T] private[sql]( */ def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + /** + * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * @since 1.6.0 + */ + def cache(): this.type = persist() + + /** + * Persist this [[Dataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, + * `MEMORY_AND_DISK_2`, etc. + * @group basic + * @since 1.6.0 + */ + def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * @since 1.6.0 + */ + def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + /** + * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @since 1.6.0 + */ + def unpersist(): this.type = unpersist(blocking = false) + /* ******************** * * Internal Functions * * ******************** */ - private[sql] def logicalPlan = queryExecution.analyzed + private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9cc65de19180a..4e26250868374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -338,6 +338,15 @@ class SQLContext private[sql]( cacheManager.lookupCachedData(table(tableName)).nonEmpty } + /** + * Returns true if the [[Queryable]] is currently cached in-memory. + * @group cachemgmt + * @since 1.3.0 + */ + private[sql] def isCached(qName: Queryable): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + /** * Caches the specified table in-memory. * @group cachemgmt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 293fcfe96e677..50f6562815c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging { } /** - * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike - * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing - * the in-memory columnar representation of the underlying table is expensive. + * Caches the data produced by the logical representation of the given [[Queryable]]. + * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because + * recomputing the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: DataFrame, + query: Queryable, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging { sqlContext.conf.useCompression, sqlContext.conf.columnBatchSize, storageLevel, - sqlContext.executePlan(query.logicalPlan).executedPlan, + sqlContext.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[DataFrame]] from the cache */ - private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[Queryable]] from the cache */ + private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -109,9 +108,11 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ + /** Tries to remove the data for the given [[Queryable]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( - query: DataFrame, + query: Queryable, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging { found } - /** Optionally returns cached data for the given [[DataFrame]] */ - private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[Queryable]] */ + private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } - /** Optionally returns cached data for the given LogicalPlan. */ + /** Optionally returns cached data for the given [[LogicalPlan]]. */ private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala new file mode 100644 index 0000000000000..3a283a4e1f610 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + + +class DatasetCacheSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("persist and unpersist") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int]) + val cached = ds.cache() + // count triggers the caching action. It should not throw. + cached.count() + // Make sure, the Dataset is indeed cached. + assertCached(cached) + // Check result. + checkAnswer( + cached, + 2, 3, 4) + // Drop the cache. + cached.unpersist() + assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + } + + test("persist and then rebind right encoder when join 2 datasets") { + val ds1 = Seq("1", "2").toDS().as("a") + val ds2 = Seq(2, 3).toDS().as("b") + + ds1.persist() + assertCached(ds1) + ds2.persist() + assertCached(ds2) + + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") + checkAnswer(joined, ("2", 2)) + assertCached(joined, 2) + + ds1.unpersist() + assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + ds2.unpersist() + assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + } + + test("persist and then groupBy columns asKey, map") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy($"_1").keyAs[String] + val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) } + agged.persist() + + checkAnswer( + agged.filter(_._1 == "b"), + ("b", 3)) + assertCached(agged.filter(_._1 == "b")) + + ds.unpersist() + assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + agged.unpersist() + assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8f476dd0f99b6..bc22fb8b7bdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.Queryable abstract class QueryTest extends PlanTest { @@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + * Asserts that a given [[Queryable]] will be executed using the given number of cached results. */ - def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { + def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached