Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup
Proposal for different unit test
massie committed Jun 23, 2015
2 parents f98a1b9 + 290f1eb commit d0a1b39
Showing 4 changed files with 170 additions and 121 deletions.

org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -20,25 +20,26 @@ package org.apache.spark.shuffle.hash
 import java.io.InputStream

 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Success}

 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-
-private[hash] class BlockStoreShuffleFetcher extends Logging {
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
+
+private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext)
+      context: TaskContext,
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
     : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager

     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))

@@ -54,7 +55,7 @@ private[hash] class BlockStoreShuffleFetcher extends Logging {
     val blockFetcherItr = new ShuffleBlockFetcherIterator(
       context,
-      SparkEnv.get.blockManager.shuffleClient,
+      blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility

org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,10 +17,10 @@

 package org.apache.spark.shuffle.hash

-import org.apache.spark.storage.BlockManager
-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter

@@ -30,7 +30,7 @@ private[spark] class HashShuffleReader[K, C](
     endPartition: Int,
     context: TaskContext,
     blockManager: BlockManager = SparkEnv.get.blockManager,
-    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -40,8 +40,8 @@ private[spark] class HashShuffleReader[K, C](

   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>

org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -17,22 +17,16 @@

 package org.apache.spark.shuffle.hash

-import java.io._
-import java.nio.ByteBuffer
+import java.io.{File, FileWriter}

 import scala.language.reflectiveCalls

-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-
-import org.apache.spark._
-import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer._
-import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockResolver
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}

 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }

test("HashShuffleReader.read() releases resources and tracks metrics") {
val shuffleId = 1
val numMaps = 2
val numKeyValuePairs = 10

val mockContext = mock(classOf[TaskContext])

val mockTaskMetrics = mock(classOf[TaskMetrics])
val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)

val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])

val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
when(mockDep.keyOrdering).thenReturn(None)
when(mockDep.aggregator).thenReturn(None)
when(mockDep.serializer).thenReturn(Some(new Serializer {
override def newInstance(): SerializerInstance = new SerializerInstance {

override def deserializeStream(s: InputStream): DeserializationStream =
new DeserializationStream {
override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]

override def close(): Unit = s.close()

private val values = {
for (i <- 0 to numKeyValuePairs * 2) yield i
}.iterator

private def getValueOrEOF(): Int = {
if (values.hasNext) {
values.next()
} else {
throw new EOFException("End of the file: mock deserializeStream")
}
}

// NOTE: the readKey and readValue methods are called by asKeyValueIterator()
// which is wrapped in a NextIterator
override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]

override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
}

override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
null.asInstanceOf[T]

override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)

override def serializeStream(s: OutputStream): SerializationStream =
null.asInstanceOf[SerializationStream]

override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
}
}))

val mockBlockManager = {
// Create a block manager that isn't configured for compression, just returns input stream
val blockManager = mock(classOf[BlockManager])
when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
.thenAnswer(new Answer[InputStream] {
override def answer(invocation: InvocationOnMock): InputStream = {
val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
inputStream
}
})
blockManager
}

val mockInputStream = mock(classOf[InputStream])
when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
.thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))

val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)

val reader = new HashShuffleReader(shuffleHandle, 0, 1,
mockContext, mockBlockManager, mockShuffleFetcher)

val values = reader.read()
// Verify that we're reading the correct values
var numValuesRead = 0
for (((key: Int, value: Int), i) <- values.zipWithIndex) {
assert(key == i * 2)
assert(value == i * 2 + 1)
numValuesRead += 1
}
// Verify that we read the correct number of values
assert(numKeyValuePairs == numValuesRead)
// Verify that our input stream was closed
verify(mockInputStream, times(1)).close()
// Verify that we collected metrics for each key/value pair
verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
}
}

org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+
+/**
+ * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
+ *
+ * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
+ * is final (final classes cannot be spied on).
+ */
+class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+  var callsToRetain = 0
+  var callsToRelease = 0
+
+  override def size() = underlyingBuffer.size()
+  override def nioByteBuffer() = underlyingBuffer.nioByteBuffer()
+  override def createInputStream() = underlyingBuffer.createInputStream()
+  override def convertToNetty() = underlyingBuffer.convertToNetty()
+
+  override def retain(): ManagedBuffer = {
+    callsToRetain += 1
+    underlyingBuffer.retain()
+  }
+  override def release(): ManagedBuffer = {
+    callsToRelease += 1
+    underlyingBuffer.release()
+  }
+}
+
+class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
+
+  /**
+   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  test("read() releases resources on completion") {
+    val testConf = new SparkConf(false)
+    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
+    // shuffle code calls SparkEnv.get()).
+    sc = new SparkContext("local", "test", testConf)
+
+    val reduceId = 15
+    val shuffleId = 22
+    val numMaps = 6
+    val keyValuePairsPerMap = 10
+    val serializer = new JavaSerializer(testConf)
+
+    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
+    // can ensure retain() and release() are properly called.
+    val blockManager = mock(classOf[BlockManager])
+
+    // Create a return function to use for the mocked wrapForCompression method that just returns
+    // the original input stream.
+    val dummyCompressionFunction = new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock) =
+        invocation.getArguments()(1).asInstanceOf[InputStream]
+    }
+
+    // Create a buffer with some key-value pairs to use as the shuffle data
+    // from each mapper (all mappers return the same shuffle data).
+    val byteOutputStream = new ByteArrayOutputStream()
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    (0 until keyValuePairsPerMap).foreach { i =>
+      serializationStream.writeKey(i)
+      serializationStream.writeValue(2*i)
+    }
+
+    // Set up the mocked BlockManager to return RecordingManagedBuffers.
+    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+    val buffers = (0 until numMaps).map { mapId =>
+      // Create a ManagedBuffer with the shuffle data.
+      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
+
+      // Set up the blockManager mock so the buffer gets returned when the shuffle code tries to
+      // fetch shuffle data.
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
+        .thenAnswer(dummyCompressionFunction)
+
+      managedBuffer
+    }
+
+    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+    // shuffle data to read.
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
+    // for the code to read data over the network.
+    val statuses: Array[(BlockManagerId, Long)] =
+      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size()))
+    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+
+    // Create a mocked shuffle handle to pass into HashShuffleReader.
+    val shuffleHandle = {
+      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+      when(dependency.serializer).thenReturn(Some(serializer))
+      when(dependency.aggregator).thenReturn(None)
+      when(dependency.keyOrdering).thenReturn(None)
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+
+    val shuffleReader = new HashShuffleReader(
+      shuffleHandle,
+      reduceId,
+      reduceId + 1,
+      new TaskContextImpl(0, 0, 0, 0, null),
+      blockManager,
+      mapOutputTracker)
+
+    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+
+    // Calling .length above will have exhausted the iterator; make sure that exhausting the
+    // iterator caused retain and release to be called on each buffer.
+    buffers.foreach { buffer =>
+      assert(buffer.callsToRetain === 1)
+      assert(buffer.callsToRelease === 1)
+    }
+  }
+}
