Add workaround for transferTo() bug in merging code; refactor tests.
JoshRosen committed May 8, 2015
1 parent 9883e30 commit 722849b
Showing 2 changed files with 225 additions and 94 deletions.
@@ -25,30 +25,37 @@
import java.nio.channels.FileChannel;
import java.util.Iterator;

import scala.Option;
import scala.Product2;
import scala.collection.JavaConversions;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
import com.google.common.io.ByteStreams;
import com.google.common.io.Files;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockManager;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.memory.TaskMemoryManager;

public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

@@ -63,6 +70,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int mapId;
private final TaskContext taskContext;
private final SparkConf sparkConf;
private final boolean transferToEnabled;

private MapStatus mapStatus = null;

@@ -95,6 +103,7 @@ public UnsafeShuffleWriter(
taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
}
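
A usage aside, not part of this commit: a minimal sketch of how a job could opt out of the transferTo merge path through the new "spark.file.transferTo" flag. The class and app name below are illustrative, not from this diff.

import org.apache.spark.SparkConf;

public class TransferToConfigExample {
  public static void main(String[] args) {
    // Disabling the flag makes UnsafeShuffleWriter fall back to the
    // stream-based merge path (mergeSpillsWithFileStream).
    final SparkConf conf = new SparkConf()
      .setAppName("transferTo-workaround-example") // illustrative name
      .set("spark.file.transferTo", "false");
    // The writer reads the flag with a default of true:
    System.out.println(conf.getBoolean("spark.file.transferTo", true)); // prints: false
  }
}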

public void write(Iterator<Product2<K, V>> records) {
@@ -116,6 +125,10 @@ private void freeMemory() {
// TODO
}

private void deleteSpills() {
// TODO
}

private SpillInfo[] insertRecordsIntoSorter(
scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
@@ -154,55 +167,127 @@ private SpillInfo[] insertRecordsIntoSorter(

private long[] mergeSpills(SpillInfo[] spills) throws IOException {
final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
return new long[partitioner.numPartitions()];
} else if (spills.length == 1) {
// Note: we'll have to watch out for corner-cases in this code path when working on shuffle
// metrics integration, since any metrics updates that are performed during the merge will
// also have to be done here. In this branch, the shuffle technically didn't need to spill
// because we're only trying to merge one file, so we may need to ensure that metrics that
// would otherwise be counted as spill metrics are actually counted as regular write
// metrics.
Files.move(spills[0].file, outputFile);
return spills[0].partitionLengths;
} else {
// Need to merge multiple spills.
if (transferToEnabled) {
return mergeSpillsWithTransferTo(spills, outputFile);
} else {
return mergeSpillsWithFileStream(spills, outputFile);
}
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
logger.error("Unable to delete output file {}", outputFile.getPath());
}
throw e;
}
}

private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
FileOutputStream mergedFileOutputStream = null;

try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new FileInputStream(spills[i].file);
}
mergedFileOutputStream = new FileOutputStream(outputFile);

      for (int partition = 0; partition < numPartitions; partition++) {
        for (int i = 0; i < spills.length; i++) {
          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
          final FileInputStream spillInputStream = spillInputStreams[i];
          ByteStreams.copy(
            new LimitedInputStream(spillInputStream, partitionLengthInSpill),
            mergedFileOutputStream);
partitionLengths[partition] += partitionLengthInSpill;
}
}
} finally {
for (int i = 0; i < spills.length; i++) {
if (spillInputStreams[i] != null) {
spillInputStreams[i].close();
}
}
if (mergedFileOutputStream != null) {
mergedFileOutputStream.close();
}
}
return partitionLengths;
}
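
For contrast, a minimal, self-contained sketch of the bounded-copy idea that LimitedInputStream provides in the loop above, using Guava's ByteStreams.limit as a stand-in. Spark ships its own LimitedInputStream in org.apache.spark.network.util, so this is illustrative rather than the code used here.

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import com.google.common.io.ByteStreams;

class BoundedCopyExample {
  // Copies at most `length` bytes from `in` to `out` and returns the number
  // of bytes actually copied, without reading past the per-partition region.
  static long copyBounded(InputStream in, OutputStream out, long length) throws IOException {
    return ByteStreams.copy(ByteStreams.limit(in, length), out);
  }
}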

private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
    final int numPartitions = partitioner.numPartitions();
    final long[] partitionLengths = new long[numPartitions];
    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
    final long[] spillInputChannelPositions = new long[spills.length];
    FileChannel mergedFileOutputChannel = null;

    try {
      for (int i = 0; i < spills.length; i++) {
        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
      }
      // This file needs to be opened in append mode in order to work around a Linux kernel bug
      // that affects transferTo; see SPARK-3948 for more details.
      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();

      long bytesWrittenToMergedFile = 0;
      for (int partition = 0; partition < numPartitions; partition++) {
        for (int i = 0; i < spills.length; i++) {
          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
          long bytesToTransfer = partitionLengthInSpill;
          final FileChannel spillInputChannel = spillInputChannels[i];
          while (bytesToTransfer > 0) {
            final long actualBytesTransferred = spillInputChannel.transferTo(
              spillInputChannelPositions[i],
              bytesToTransfer,
              mergedFileOutputChannel);
            spillInputChannelPositions[i] += actualBytesTransferred;
            bytesToTransfer -= actualBytesTransferred;
          }
          bytesWrittenToMergedFile += partitionLengthInSpill;
          partitionLengths[partition] += partitionLengthInSpill;
        }
      }
      // Check the position after the transferTo loop and raise an exception if it is incorrect.
      // The position will not be advanced to the expected length after calling transferTo on
      // kernel version 2.6.32. This issue is described at
      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
        throw new IOException(
          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
          "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel " +
          "version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
          "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
          "to disable this NIO feature."
        );
      }
    } finally {
      for (int i = 0; i < spills.length; i++) {
        assert(spillInputChannelPositions[i] == spills[i].file.length());
        if (spillInputChannels[i] != null) {
          spillInputChannels[i].close();
        }
      }
      if (mergedFileOutputChannel != null) {
        mergedFileOutputChannel.close();
      }
    }
    return partitionLengths;
  }
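
The core pattern above, in isolation: transferTo() may transfer fewer bytes than requested, so callers must loop until the requested range is exhausted. A minimal sketch (the helper name is illustrative):

import java.io.IOException;
import java.nio.channels.FileChannel;

class TransferToLoopExample {
  // Copies `count` bytes starting at `position` from `in` to `out`,
  // retrying until the full range has been transferred.
  static void copyRange(FileChannel in, long position, long count, FileChannel out)
      throws IOException {
    long transferred = 0;
    while (transferred < count) {
      transferred += in.transferTo(position + transferred, count - transferred, out);
    }
  }
}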

@@ -215,6 +300,9 @@ public Option<MapStatus> stop(boolean success) {
stopping = true;
freeMemory();
if (success) {
if (mapStatus == null) {
throw new IllegalStateException("Cannot call stop(true) without having called write()");
}
return Option.apply(mapStatus);
} else {
// The map task failed, so delete our output data.
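
The second changed file (the refactored test suite) is not shown in this excerpt. As a hypothetical sketch only, a JUnit 4 test for the new guard above might look like the following, assuming a `writer` fixture constructed with the usual mocks:

  @Test(expected = IllegalStateException.class)
  public void stopWithSuccessBeforeAnyWriteThrows() {
    // stop(true) without a prior write() leaves mapStatus null and must fail fast.
    writer.stop(true);
  }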
