[SPARK-4105][SC-5349][BACKPORT] retry the fetch or stage if shuffle block is corrupt

(backport from upstream master to databricks branch 2.1)

## What changes were proposed in this pull request?

There is an outstanding issue that has existed for a long time: sometimes shuffle blocks are corrupt and can't be decompressed. We recently hit this in three different workloads; sometimes we can reproduce it on every attempt, sometimes we can't. I also found that when the corruption happened, the beginning and end of the block were correct and the corruption was in the middle. In one case the string of a block id was corrupted by a single character. It seems very likely that the corruption is introduced by some faulty machine/hardware, and the 16-bit checksum in TCP is not strong enough to identify all corruption.

Unfortunately, Spark has no checksum for shuffle blocks or broadcast data, so a job will fail if any corruption occurs in a shuffle block read from disk or in a broadcast block sent over the network. This PR tries to detect corruption after fetching shuffle blocks by decompressing them, because most compression codecs already embed a checksum. On a corrupt block it will retry the fetch, or fail with FetchFailure so that the previous stage can be retried on different (still random) machines.
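
The core idea can be sketched as follows (a minimal, self-contained illustration rather than the actual patch: `fetchBlock` and `decompress` stand in for the real block fetch and for `SerializerManager.wrapStream`, and the real code only buffers small, wrapped blocks and re-enqueues a FetchRequest instead of retrying inline):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}

// Sketch only: drain the fetched block through its decompressing wrapper so the
// codec's built-in checksum turns silent corruption into an IOException, and
// retry the fetch once before giving up.
def fetchVerified(fetchBlock: () => InputStream,
                  decompress: InputStream => InputStream): InputStream = {
  def readFully(): InputStream = {
    val in = decompress(fetchBlock())
    val out = new ByteArrayOutputStream()
    val buf = new Array[Byte](64 * 1024)
    var n = in.read(buf)
    while (n != -1) {  // an IOException here means the block is corrupt
      out.write(buf, 0, n)
      n = in.read(buf)
    }
    in.close()
    new ByteArrayInputStream(out.toByteArray)
  }
  try readFully() catch {
    case _: IOException => readFully()  // a second corrupt copy propagates the error
  }
}
```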

Checksums for broadcast blocks will be added in a separate PR.

## How was this patch tested?

Added unit tests
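
For illustration, the kind of failure such tests need to simulate can be produced with a stream that breaks partway through; this helper is hypothetical, not taken from the actual suite:

```scala
import java.io.{IOException, InputStream}

// Hypothetical test helper: a stream that fails mid-read, the way a corrupt
// shuffle block surfaces once the decompressing wrapper hits bad data.
class CorruptingInputStream(data: Array[Byte], failAt: Int) extends InputStream {
  private var pos = 0
  override def read(): Int = {
    if (pos == failAt) throw new IOException("simulated corruption")
    else if (pos >= data.length) -1
    else {
      pos += 1
      data(pos - 1) & 0xFF
    }
  }
}
```

A test would feed such a stream through the fetcher and assert that the first corrupt copy triggers exactly one re-fetch, while a second corrupt copy surfaces as a FetchFailedException.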

Author: Davies Liu <davies@databricks.com>

Closes apache#15923 from davies/detect_corrupt.

Author: Davies Liu <davies@databricks.com>

Closes apache#159 from ericl/sc-5362.
Davies Liu authored and yhuai committed Jan 24, 2017
1 parent 671c5f6 commit b333054
Showing 4 changed files with 263 additions and 57 deletions.
BlockStoreShuffleReader.scala
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](

   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
-      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
-
-    // Wrap the streams for compression and encryption based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapStream(blockId, inputStream)
-    }
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,21 @@

 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,17 +49,21 @@ import org.apache.spark.util.Utils
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
   import ShuffleBlockFetcherIterator._
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /**
+   * The blocks that couldn't be decompressed successfully; used to guarantee that we retry
+   * them at most once.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -123,9 +135,8 @@
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }
@@ -309,40 +320,84 @@
     }
 
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
-    }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
-
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
-
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
-    }
+    var result: FetchResult = null
+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt, and throw FailureFetchResult if the
+    // second fetch is also corrupt, so the previous stage could be retried.
+    // For a local shuffle block, throw FailureFetchResult on the first IOException.
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
+
+      result match {
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // The exception can only be thrown by a local shuffle block
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              throwFetchFailedException(blockId, address, e)
+          }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, and only if the
+          // block is small (the decompressed block is smaller than maxBytesInFlight)
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val originalInput = input
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption, which could
+              // increase the memory usage and potentially the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)
+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                    || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              originalInput.close()
+              in.close()
+            }
+          }
+
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
+    }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
 
   private def fetchUpToMaxBytes(): Unit = {
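
The detection in the hunk above leans on the codec's own integrity checks. As a standalone illustration (using gzip from `java.util.zip` here purely for convenience; Spark's shuffle codecs such as LZ4 carry similar checksums), flipping a single byte in the middle of a compressed payload makes draining the stream fail:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

// Compress some data, flip one byte in the middle, and drain the stream:
// the corruption surfaces as an IOException instead of silently bad records.
val out = new ByteArrayOutputStream()
val gz = new GZIPOutputStream(out)
gz.write(Array.fill[Byte](1 << 16)(42.toByte))
gz.close()
val bytes = out.toByteArray
bytes(bytes.length / 2) = (bytes(bytes.length / 2) ^ 0x1).toByte

val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
try {
  val buf = new Array[Byte](8192)
  while (in.read(buf) != -1) {}
  println("no corruption detected")
} catch {
  case e: IOException => println(s"detected corruption: ${e.getMessage}")
}
```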
ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
  *                in order to close any memory-mapped files which back the buffer.
  */
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
     var chunkedByteBuffer: ChunkedByteBuffer,
     dispose: Boolean)
   extends InputStream {
[Diff for the fourth changed file (the new unit tests) is not shown.]
