[SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader

This commit updates the shuffle read path to give ShuffleReader implementations more control over the deserialization process.

The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; now, it returns an iterator of (BlockId, InputStream). Deserialization of records is now handled in the ShuffleReader.read() method.

This change creates a cleaner separation of concerns and gives ShuffleReader implementations more flexibility in how records are retrieved and deserialized.
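For illustration, a minimal sketch (not part of this commit) of how a reader built on the new API might consume the fetched streams. The object ShuffleReadSketch, the method exampleRead, and the wrapForCompression parameter are hypothetical names; BlockId, SerializerInstance, deserializeStream, and asKeyValueIterator are the existing Spark APIs used by this change.

import java.io.InputStream

import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.BlockId

object ShuffleReadSketch {
  // Hypothetical sketch: the fetcher now yields raw (BlockId, InputStream) pairs, and
  // decompression plus deserialization happen in the reader.
  def exampleRead(
      blockStreams: Iterator[(BlockId, InputStream)],            // as returned by fetchBlockStreams()
      wrapForCompression: (BlockId, InputStream) => InputStream, // e.g. blockManager.wrapForCompression
      serializer: SerializerInstance): Iterator[(Any, Any)] = {
    blockStreams.flatMap { case (blockId, rawStream) =>
      val decompressed = wrapForCompression(blockId, rawStream)
      // asKeyValueIterator wraps the stream in a NextIterator that closes it
      // once all records have been read.
      serializer.deserializeStream(decompressed).asKeyValueIterator
    }
  }
}

The actual HashShuffleReader.read() additionally counts each record in the shuffle read metrics and wraps the result in an InterruptibleIterator so the task can be cancelled mid-read.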

Author: Matt Massie <massie@cs.berkeley.edu>
Author: Kay Ousterhout <kayousterhout@gmail.com>

Closes apache#6423 from massie/shuffle-api-cleanup and squashes the following commits:

8b0632c [Matt Massie] Minor Scala style fixes
d0a1b39 [Matt Massie] Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup
290f1eb [Kay Ousterhout] Added test for HashShuffleReader.read()
5186da0 [Kay Ousterhout] Revert "Add test to ensure HashShuffleReader is freeing resources"
f98a1b9 [Matt Massie] Add test to ensure HashShuffleReader is freeing resources
a011bfa [Matt Massie] Use PrivateMethodTester on check that delegate stream is closed
4ea1712 [Matt Massie] Small code cleanup for readability
7429a98 [Matt Massie] Update tests to check that BufferReleasingStream is closing delegate InputStream
f458489 [Matt Massie] Remove unnecessary map() on return Iterator
4abb855 [Matt Massie] Consolidate metric code. Make it clear why InterruptibleIterator is needed.
5c30405 [Matt Massie] Return visibility of BlockStoreShuffleFetcher to private[hash]
7eedd1d [Matt Massie] Small Scala import cleanup
28f8085 [Matt Massie] Small import nit
f93841e [Matt Massie] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher.
7e8e0fe [Matt Massie] Minor Scala style fixes
01e8721 [Matt Massie] Explicitly cast iterator in branches for type clarity
7c8f73e [Matt Massie] Close Block InputStream immediately after all records are read
208b7a5 [Matt Massie] Small code style changes
b70c945 [Matt Massie] Make BlockStoreShuffleFetcher visible to shuffle package
19135f2 [Matt Massie] [SPARK-7884] Allow Spark shuffle APIs to be more customizable
massie authored and kayousterhout committed Jun 25, 2015
1 parent 82f80c1 commit 7bac2fe
Showing 5 changed files with 314 additions and 96 deletions.
BlockStoreShuffleFetcher.scala
@@ -17,29 +17,29 @@

package org.apache.spark.shuffle.hash

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.util.{Failure, Success, Try}
import java.io.InputStream

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.{Failure, Success}

import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
ShuffleBlockId}

private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
: Iterator[(BlockId, InputStream)] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager

val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))

@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}

def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
blocksByAddress,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

// Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
blockFetcherItr.map { blockPair =>
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
case Success(inputStream) => {
(blockId, inputStream)
}
case Failure(e) => {
blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
}

val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[T](context, completionIter) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): T = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}
}
}
HashShuffleReader.scala
@@ -17,16 +17,20 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
@@ -36,20 +36,52 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

// Wrap the streams for compression based on configuration
val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}

val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val serializerInstance = ser.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure it's compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

// Sort the output if there is a sort ordering defined.
ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,23 @@

package org.apache.spark.storage

import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue

import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.{Failure, Try}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.util.{CompletionIterator, Utils}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.util.Utils

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
* This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
* pipelined fashion as they are received.
* This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
* using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
* @param serializer serializer used to deserialize the data.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
*/
private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
extends Iterator[(BlockId, Try[InputStream])] with Logging {

import ShuffleBlockFetcherIterator._

@@ -83,7 +81,7 @@

/**
* A queue to hold our results. This turns the asynchronous model provided by
* [[BlockTransferService]] into a synchronous model (iterator).
* [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]

@@ -102,9 +100,7 @@
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L

private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()

private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@

initialize()

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
// Decrements the buffer reference count.
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
case SuccessFetchResult(_, _, buf) => buf.release()
case _ =>
}
currentResult = null
}

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
releaseCurrentResultBuffer()
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
@@ -272,7 +274,13 @@

override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

override def next(): (BlockId, Try[Iterator[Any]]) = {
/**
* Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
* underlying each InputStream will be freed by the cleanup() method registered with the
* TaskCompletionListener. However, callers should close() these InputStreams
* as soon as they are no longer needed, in order to release memory as early as possible.
*/
override def next(): (BlockId, Try[InputStream]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -290,29 +298,55 @@
sendRequest(fetchRequests.dequeue())
}

val iteratorTry: Try[Iterator[Any]] = result match {
val iteratorTry: Try[InputStream] = result match {
case FailureFetchResult(_, e) =>
Failure(e)
case SuccessFetchResult(blockId, _, buf) =>
// There is a chance that createInputStream can fail (e.g. fetching a local file that does
// not exist, SPARK-4085). In that case, we should propagate the right exception so
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
currentResult = null
buf.release()
})
Try(buf.createInputStream()).map { inputStream =>
new BufferReleasingInputStream(inputStream, this)
}
}

(result.blockId, iteratorTry)
}
}

/**
* Helper class that ensures a ManagedBuffer is released upon InputStream.close()
*/
private class BufferReleasingInputStream(
private val delegate: InputStream,
private val iterator: ShuffleBlockFetcherIterator)
extends InputStream {
private[this] var closed = false

override def read(): Int = delegate.read()

override def close(): Unit = {
if (!closed) {
delegate.close()
iterator.releaseCurrentResultBuffer()
closed = true
}
}

override def available(): Int = delegate.available()

override def mark(readlimit: Int): Unit = delegate.mark(readlimit)

override def skip(n: Long): Long = delegate.skip(n)

override def markSupported(): Boolean = delegate.markSupported()

override def read(b: Array[Byte]): Int = delegate.read(b)

override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)

override def reset(): Unit = delegate.reset()
}

private[storage]
object ShuffleBlockFetcherIterator {