From 5186da0697d5a1efe4229b7b0a224979ce7f2bc7 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 23 Jun 2015 11:34:58 -0700 Subject: [PATCH 1/2] Revert "Add test to ensure HashShuffleReader is freeing resources" This reverts commit f98a1b9503ed8fb4c5fc3e9033a744c254237c45. --- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../shuffle/hash/HashShuffleReader.scala | 8 +- .../hash/HashShuffleManagerSuite.scala | 114 +----------------- 3 files changed, 10 insertions(+), 115 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index aefb2f5685537..0635b98742096 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -26,8 +26,7 @@ import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -private[hash] class BlockStoreShuffleFetcher extends Logging { - +private[hash] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index b868f32f5cce1..ca6eddf8d5c12 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.storage.BlockManager import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} @@ -28,19 +27,18 @@ private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext, - blockManager: BlockManager = SparkEnv.get.blockManager, - blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher) + context: TaskContext) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency + private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams( + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( handle.shuffleId, startPartition, context) // Wrap the streams for compression based on configuration diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 53b2b89a5e641..491dc3659e184 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -17,22 +17,16 @@ package org.apache.spark.shuffle.hash -import java.io._ -import java.nio.ByteBuffer +import java.io.{File, FileWriter} import scala.language.reflectiveCalls -import org.mockito.Matchers.any -import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer - -import 
org.apache.spark._ -import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer._ -import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver} -import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FileShuffleBlockResolver +import org.apache.spark.storage.{ShuffleBlockId, FileSegment} class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { private val testConf = new SparkConf(false) @@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until numBytes) writer.write(i) writer.close() } - - test("HashShuffleReader.read() releases resources and tracks metrics") { - val shuffleId = 1 - val numMaps = 2 - val numKeyValuePairs = 10 - - val mockContext = mock(classOf[TaskContext]) - - val mockTaskMetrics = mock(classOf[TaskMetrics]) - val mockReadMetrics = mock(classOf[ShuffleReadMetrics]) - when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics) - when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics) - - val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher]) - - val mockDep = mock(classOf[ShuffleDependency[_, _, _]]) - when(mockDep.keyOrdering).thenReturn(None) - when(mockDep.aggregator).thenReturn(None) - when(mockDep.serializer).thenReturn(Some(new Serializer { - override def newInstance(): SerializerInstance = new SerializerInstance { - - override def deserializeStream(s: InputStream): DeserializationStream = - new DeserializationStream { - override def readObject[T: ClassManifest](): T = null.asInstanceOf[T] - - override def close(): Unit = s.close() - - private val values = { - for (i <- 0 to numKeyValuePairs * 2) yield i - }.iterator - - private def getValueOrEOF(): Int = { - if (values.hasNext) { - values.next() - } else { - throw new EOFException("End of the file: mock deserializeStream") - } - } - - // NOTE: the readKey and readValue methods are called by asKeyValueIterator() - // which is wrapped in a NextIterator - override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] - - override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T] - } - - override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T = - null.asInstanceOf[T] - - override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0) - - override def serializeStream(s: OutputStream): SerializationStream = - null.asInstanceOf[SerializationStream] - - override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T] - } - })) - - val mockBlockManager = { - // Create a block manager that isn't configured for compression, just returns input stream - val blockManager = mock(classOf[BlockManager]) - when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]())) - .thenAnswer(new Answer[InputStream] { - override def answer(invocation: InvocationOnMock): InputStream = { - val blockId = invocation.getArguments()(0).asInstanceOf[BlockId] - val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream] - inputStream - } - }) - blockManager - } - - val mockInputStream 
= mock(classOf[InputStream]) - when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]())) - .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream))) - - val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep) - - val reader = new HashShuffleReader(shuffleHandle, 0, 1, - mockContext, mockBlockManager, mockShuffleFetcher) - - val values = reader.read() - // Verify that we're reading the correct values - var numValuesRead = 0 - for (((key: Int, value: Int), i) <- values.zipWithIndex) { - assert(key == i * 2) - assert(value == i * 2 + 1) - numValuesRead += 1 - } - // Verify that we read the correct number of values - assert(numKeyValuePairs == numValuesRead) - // Verify that our input stream was closed - verify(mockInputStream, times(1)).close() - // Verify that we collected metrics for each key/value pair - verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1) - } } From 290f1eb356024fb58a209e9fc6c8800bfc0e6688 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 23 Jun 2015 15:41:06 -0700 Subject: [PATCH 2/2] Added test for HashShuffleReader.read() --- .../hash/BlockStoreShuffleFetcher.scala | 14 +- .../shuffle/hash/HashShuffleReader.scala | 10 +- .../shuffle/hash/HashShuffleReaderSuite.scala | 150 ++++++++++++++++++ 3 files changed, 164 insertions(+), 10 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 0635b98742096..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -20,24 +20,26 @@ package org.apache.spark.shuffle.hash import java.io.InputStream import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.{Failure, Success, Try} +import scala.util.{Failure, Success} import org.apache.spark._ import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { def fetchBlockStreams( shuffleId: Int, reduceId: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,7 +55,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockManager.shuffleClient, + blockManager.shuffleClient, blockManager, blocksByAddress, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index ca6eddf8d5c12..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -27,19 +28,20 @@ private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") private val dep = handle.dependency - private val blockManager = SparkEnv.get.blockManager /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context) + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) // Wrap the streams for compression based on configuration val wrappedStreams = blockStreams.map { case (blockId, inputStream) => diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..0add85c6377dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size() = underlyingBuffer.size() + override def nioByteBuffer() = underlyingBuffer.nioByteBuffer() + override def createInputStream() = underlyingBuffer.createInputStream() + override def convertToNetty() = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock) = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. 
+ val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size())) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +}
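The second patch above makes the reader testable by turning hard-wired `SparkEnv.get` lookups into constructor parameters with production defaults (`blockManager` and `mapOutputTracker` on `HashShuffleReader`, and explicit arguments to `BlockStoreShuffleFetcher.fetchBlockStreams`), so production call sites stay unchanged while the new suite injects mocks. Below is a minimal, self-contained sketch of that dependency-injection pattern; `Env`, `Tracker`, and `Reader` are hypothetical stand-ins for illustration, not Spark's real classes.

```scala
// Sketch of constructor injection with a production default, as used in the patch.
// All names here are illustrative stand-ins, not Spark APIs.
trait Tracker {
  def serverStatuses(shuffleId: Int): Seq[Long]
}

object Env {
  // Plays the role SparkEnv.get plays in Spark: a global source of collaborators.
  val tracker: Tracker = new Tracker {
    def serverStatuses(shuffleId: Int): Seq[Long] = Seq(1L, 2L, 3L)
  }
}

// The collaborator is a constructor parameter that defaults to the global one,
// so existing production callers compile and behave exactly as before.
class Reader(shuffleId: Int, tracker: Tracker = Env.tracker) {
  def read(): Seq[Long] = tracker.serverStatuses(shuffleId)
}

object ReaderExample {
  def main(args: Array[String]): Unit = {
    // Production-style use: the default collaborator is picked up implicitly.
    println(new Reader(shuffleId = 1).read())

    // Test-style use: a fake collaborator is injected, no global environment needed.
    val fakeTracker = new Tracker {
      def serverStatuses(shuffleId: Int): Seq[Long] = Seq(42L)
    }
    assert(new Reader(shuffleId = 1, tracker = fakeTracker).read() == Seq(42L))
  }
}
```

This mirrors why the new `HashShuffleReaderSuite` can pass a mocked `BlockManager` and `MapOutputTracker` directly to the reader instead of standing up a full `SparkEnv`.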