[SPARK-7884] Allow Spark shuffle APIs to be more customizable
This commit updates the shuffle read path to give ShuffleReader
implementations more control over the deserialization process.

The BlockStoreShuffleFetcher.fetch() method has been renamed to
BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method
returned a record iterator; now, it returns an iterator of
(BlockId, InputStream) pairs. Deserialization of records is now handled
in the ShuffleReader.read() method.

This change creates a cleaner separation of concerns and gives
ShuffleReader implementations more flexibility in how records are
deserialized.
massie committed Jun 2, 2015
1 parent 6396cc0 commit 19135f2
Showing 4 changed files with 119 additions and 71 deletions.
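
The diffs below introduce the new seam. As a rough illustration of what it enables (a hypothetical sketch, not code from this commit: CustomShuffleReader and its body are invented), a reader implementation can now take the raw block streams and decide its own decompression and deserialization:

package org.apache.spark.shuffle.hash

import java.io.InputStream

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.BlockId

// Hypothetical reader built on the new API: fetch raw streams, then apply
// whatever decompression and serialization scheme this reader prefers.
private[spark] class CustomShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    context: TaskContext)
  extends ShuffleReader[K, C] {

  override def read(): Iterator[Product2[K, C]] = {
    // New contract: an iterator of (BlockId, InputStream), not deserialized records.
    val blockStreams: Iterator[(BlockId, InputStream)] =
      BlockStoreShuffleFetcher.fetchBlockStreams(handle.shuffleId, startPartition, context)

    val serializerInstance =
      Serializer.getSerializer(handle.dependency.serializer).newInstance()

    // Decompression and deserialization are now choices made by the reader.
    blockStreams.flatMap { case (blockId, rawStream) =>
      val stream = SparkEnv.get.blockManager.wrapForCompression(blockId, rawStream)
      serializerInstance.deserializeStream(stream).asKeyValueIterator
    }.asInstanceOf[Iterator[Product2[K, C]]]
  }
}
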
BlockStoreShuffleFetcher.scala
@@ -17,23 +17,22 @@

package org.apache.spark.shuffle.hash

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import java.io.InputStream

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.{Failure, Success, Try}

import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator

private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
context: TaskContext)
: Iterator[(BlockId, InputStream)] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
@@ -53,12 +52,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}

def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
def unpackBlock(blockPair: (BlockId, Try[InputStream])) : (BlockId, InputStream) = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
case Success(inputStream) => {
(blockId, inputStream)
}
case Failure(e) => {
blockId match {
@@ -78,21 +77,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})
val itr = blockFetcherItr.map(unpackBlock)

new InterruptibleIterator[T](context, completionIter) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): T = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}
CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
context.taskMetrics().updateShuffleReadMetrics()
})
}
}
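
A note on the throttle configured above: the in-flight limit comes from spark.reducer.maxSizeInFlight, read with getSizeAsMb so that a legacy, suffix-less value is still interpreted as megabytes. A small sketch with hypothetical configuration values:

import org.apache.spark.SparkConf

// Hypothetical settings: both forms resolve to the same 96 MB limit, because a
// suffix-less value is treated as megabytes for backwards compatibility.
val conf = new SparkConf()
conf.set("spark.reducer.maxSizeInFlight", "96m")   // explicit megabyte suffix
// conf.set("spark.reducer.maxSizeInFlight", "96") // legacy suffix-less form

val maxBytesInFlight =
  conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
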
HashShuffleReader.scala
@@ -17,10 +17,10 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -33,11 +33,34 @@ private[spark] class HashShuffleReader[K, C](
"Hash shuffle currently only supports fetching one partition")

private val dep = handle.dependency
private val blockManager = SparkEnv.get.blockManager

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
handle.shuffleId, startPartition, context)

// Wrap the streams for compression based on configuration
val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}

val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val serializerInstance = ser.newInstance()

// Create a key/value iterator for each stream
val recordIterator = wrappedStreams.flatMap { wrappedStream =>
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

// Update read metrics for each record materialized
val iter = new InterruptibleIterator[Any](context, recordIterator) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): Any = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}.asInstanceOf[Iterator[Nothing]]

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
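
Because deserialization now happens in read() rather than inside the fetcher, the decoding step can be exercised against plain in-memory streams, with no block manager or network fetch involved. A rough sketch of that idea (hypothetical example, not code from this commit):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Write a few key/value records the way a map task would serialize them...
val serializerInstance = new JavaSerializer(new SparkConf()).newInstance()
val bytes = new ByteArrayOutputStream()
val out = serializerInstance.serializeStream(bytes)
out.writeObject("a").writeObject(1)
out.writeObject("b").writeObject(2)
out.close()

// ...then decode them the same way the new HashShuffleReader.read() does.
val records = serializerInstance
  .deserializeStream(new ByteArrayInputStream(bytes.toByteArray))
  .asKeyValueIterator
  .toList

assert(records == List(("a", 1), ("b", 2)))
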
ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,24 @@

package org.apache.spark.storage

import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue

import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Try}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.util.{CompletionIterator, Utils}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, TaskContext}

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
* This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
* pipelined fashion as they are received.
* This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
* using too much memory.
Expand All @@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
* @param serializer serializer used to deserialize the data.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
*/
private[spark]
@@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
extends Iterator[(BlockId, Try[InputStream])] with Logging {

import ShuffleBlockFetcherIterator._

@@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator(
private[this] val localBlocks = new ArrayBuffer[BlockId]()

/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
private[this] val remoteBlocks = new mutable.HashSet[BlockId]()

/**
* A queue to hold our results. This turns the asynchronous model provided by
* [[BlockTransferService]] into a synchronous model (iterator).
* [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]

@@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator(
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
* the number of bytes in flight is limited to maxBytesInFlight.
*/
private[this] val fetchRequests = new Queue[FetchRequest]
private[this] val fetchRequests = new mutable.Queue[FetchRequest]

/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L

private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()

private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +111,23 @@

initialize()

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
// Decrements the buffer reference count.
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
case SuccessFetchResult(_, _, buf) => buf.release()
case _ =>
}
currentResult = null
}

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
releaseCurrentResultBuffer()
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
@@ -272,7 +275,7 @@

override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

override def next(): (BlockId, Try[Iterator[Any]]) = {
override def next(): (BlockId, Try[InputStream]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -290,29 +293,51 @@
sendRequest(fetchRequests.dequeue())
}

val iteratorTry: Try[Iterator[Any]] = result match {
val iteratorTry: Try[InputStream] = result match {
case FailureFetchResult(_, e) =>
Failure(e)
case SuccessFetchResult(blockId, _, buf) =>
// There is a chance that createInputStream can fail (e.g. fetching a local file that does
// not exist, SPARK-4085). In that case, we should propagate the right exception so
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
currentResult = null
buf.release()
})
Try(buf.createInputStream()).map { inputStream =>
new WrappedInputStream(inputStream, this)
}
}

(result.blockId, iteratorTry)
}
}

// Helper class that ensures a ManagedBuffer is released upon InputStream.close()
private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
extends InputStream {
private var closed = false

override def read(): Int = delegate.read()

override def close(): Unit = {
if (!closed) {
delegate.close()
iterator.releaseCurrentResultBuffer()
closed = true
}
}

override def available(): Int = delegate.available()

override def mark(readlimit: Int): Unit = delegate.mark(readlimit)

override def skip(n: Long): Long = delegate.skip(n)

override def markSupported(): Boolean = delegate.markSupported()

override def read(b: Array[Byte]): Int = delegate.read(b)

override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)

override def reset(): Unit = delegate.reset()
}
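
The wrapper above releases the fetched block's ManagedBuffer when the stream is closed, so callers should make sure each stream is closed even if reading it fails part-way. A small, hypothetical caller-side helper illustrating that pattern (not part of this commit):

import java.io.InputStream

// Run `body` on a fetched block stream and always close the stream afterwards,
// so the backing ManagedBuffer is released even when deserialization throws.
def withBlockStream[T](stream: InputStream)(body: InputStream => T): T = {
  try {
    body(stream)
  } finally {
    stream.close()
  }
}

// Usage sketch: count the bytes in one fetched block.
// val size = withBlockStream(stream)(s => Iterator.continually(s.read()).takeWhile(_ != -1).length)
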

private[storage]
object ShuffleBlockFetcherIterator {