Add tests for proper cleanup of shuffle data.
JoshRosen committed May 13, 2015
1 parent d494ffe · commit 7610f2f
Showing 3 changed files with 92 additions and 8 deletions.
core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     true
   }
 
-  override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     indexShuffleBlockResolver
   }

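A note on the SortShuffleManager change above: turning shuffleBlockResolver from a def into a val makes it a stable, eagerly initialized member instead of a method that is re-evaluated on every call. In this diff the body just returns an existing field, so behavior is unchanged, but the general difference between def and val overrides is worth seeing. A minimal standalone sketch, not from this commit (Resolver, DefManager, ValManager, and OverrideDemo are made-up names):

class Resolver

abstract class Manager {
  def resolver: Resolver
}

class DefManager extends Manager {
  // A def override runs its body on every access, yielding a new instance each time.
  override def resolver: Resolver = new Resolver
}

class ValManager extends Manager {
  // A val override is evaluated once at construction, so callers share one instance.
  override val resolver: Resolver = new Resolver
}

object OverrideDemo extends App {
  val d = new DefManager
  println(d.resolver eq d.resolver) // false: two distinct objects
  val v = new ValManager
  println(v.resolver eq v.resolver) // true: the cached value
}

Because a val override is evaluated exactly once, callers can safely hold on to the returned reference.
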
core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.shuffle.unsafe
 
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
 import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
@@ -25,7 +28,7 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
 /**
  * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
  */
-private class UnsafeShuffleHandle[K, V](
+private[spark] class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
     numMaps: Int,
     dependency: ShuffleDependency[K, V, V])
@@ -121,8 +124,10 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
       "manager; its optimized shuffles will continue to spill to disk when necessary.")
   }
 
-
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
 
   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -158,8 +163,8 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
       context: TaskContext): ShuffleWriter[K, V] = {
     handle match {
       case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
         val env = SparkEnv.get
-        // TODO: do we need to do anything to register the shuffle here?
         new UnsafeShuffleWriter(
           env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
@@ -170,17 +175,26 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager
           context,
           env.conf)
       case other =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
         sortShuffleManager.getWriter(handle, mapId, context)
     }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    // TODO: need to do something here for our unsafe path
-    sortShuffleManager.unregisterShuffle(shuffleId)
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
   }
 
-  override def shuffleBlockResolver: ShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     sortShuffleManager.shuffleBlockResolver
   }
 
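The unregisterShuffle change above is the heart of this commit: shuffles that fell back to sort-based shuffle are delegated to the wrapped SortShuffleManager for cleanup, while shuffles that used the new path have each map task's output removed via removeDataByMap. Collections.newSetFromMap over a ConcurrentHashMap is the standard JDK idiom for a concurrent hash set (the JDK has no ConcurrentHashSet; Java 8 later added ConcurrentHashMap.newKeySet). A compilable sketch of the same bookkeeping, with hypothetical names (CleanupModel and deleteMapOutput stand in for UnsafeShuffleManager and IndexShuffleBlockResolver.removeDataByMap):

import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

object CleanupModel {
  // JDK idiom for a concurrent hash set: a set view backed by a ConcurrentHashMap.
  private val fellBackToSortShuffle =
    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
  private val numMapsForNewPath = new ConcurrentHashMap[Int, Int]()

  def registerFallback(shuffleId: Int): Unit =
    fellBackToSortShuffle.add(shuffleId)

  def registerNewPath(shuffleId: Int, numMaps: Int): Unit =
    numMapsForNewPath.putIfAbsent(shuffleId, numMaps)

  def unregister(shuffleId: Int)(deleteMapOutput: (Int, Int) => Unit): Boolean = {
    if (fellBackToSortShuffle.remove(shuffleId)) {
      // Sort-based shuffle: the wrapped sort shuffle manager owns these files.
      true
    } else {
      // New path: delete each map task's .data/.index output ourselves.
      // remove() unboxes to 0 for an unknown shuffle, so the loop is then a no-op.
      val numMaps = numMapsForNewPath.remove(shuffleId)
      (0 until numMaps).foreach(mapId => deleteMapOutput(shuffleId, mapId))
      true
    }
  }
}

In the real manager, the fallback branch returns the result of sortShuffleManager.unregisterShuffle rather than a literal true.
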
core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -17,9 +17,17 @@

 package org.apache.spark.shuffle.unsafe
 
-import org.apache.spark.ShuffleSuite
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
 class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
@@ -30,4 +38,66 @@ class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
     // shuffle records.
     conf.set("spark.shuffle.memoryFraction", "0.5")
   }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
 }
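Both tests expect the same two file names because IndexShuffleBlockResolver writes one shuffle_<shuffleId>_<mapId>_<reduceId>.data file and one matching .index file per map task, with the reduce ID always 0 in these names; one shuffle (id 0) with a single map partition therefore yields exactly shuffle_0_0_0.data and shuffle_0_0_0.index. Since the snapshot-and-diff logic is duplicated verbatim across the two tests, a shared helper is a natural refactoring; a sketch under that assumption (ShuffleFileAssertions and assertCleanedUp are hypothetical, not part of this commit):

import java.io.File

import scala.collection.JavaConverters._

import org.apache.commons.io.FileUtils
import org.apache.commons.io.filefilter.TrueFileFilter

object ShuffleFileAssertions {

  // Recursively lists every file under dir, as in the tests' getAllFiles helper.
  private def allFiles(dir: File): Set[File] =
    FileUtils.listFiles(dir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet

  /**
   * Runs `action`, records which files it created under `dir`, then runs `cleanup`
   * and fails if any of those files still exist.
   */
  def assertCleanedUp(dir: File)(action: => Unit)(cleanup: => Unit): Unit = {
    val before = allFiles(dir)
    action
    val created = allFiles(dir) -- before
    assert(created.nonEmpty, "action did not create any files to clean up")
    cleanup
    val leaked = created.filter(_.exists())
    assert(leaked.isEmpty, s"Files not cleaned up: ${leaked.mkString(", ")}")
  }
}

With such a helper, the two tests would differ only in the serializer they set and in the polarity of the canUseUnsafeShuffle assertion.
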
