diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 597d46a3d2223..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..e49e39679e940 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..28ca68698e3dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad8585..9ced4148d7206 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { + +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks @@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. - iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.get.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator @@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure