From 722849b3f77dcbf3494f6a76a219d7f29b7a8284 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 8 May 2015 16:31:03 -0700
Subject: [PATCH] Add workaround for transferTo() bug in merging code; refactor tests.

---
 .../shuffle/unsafe/UnsafeShuffleWriter.java   | 158 +++++++++++++----
 .../unsafe/UnsafeShuffleWriterSuite.java      | 161 +++++++++++-------
 2 files changed, 225 insertions(+), 94 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 1b5af45334238..206812f8352d2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -25,7 +25,6 @@
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
 
-import org.apache.spark.shuffle.ShuffleMemoryManager;
 import scala.Option;
 import scala.Product2;
 import scala.collection.JavaConversions;
@@ -33,15 +32,21 @@
 import scala.reflect.ClassTag$;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -49,6 +54,8 @@
 
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
@@ -63,6 +70,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int mapId;
   private final TaskContext taskContext;
   private final SparkConf sparkConf;
+  private final boolean transferToEnabled;
 
   private MapStatus mapStatus = null;
 
@@ -95,6 +103,7 @@ public UnsafeShuffleWriter(
     taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
   }
 
   public void write(Iterator<Product2<K, V>> records) {
@@ -116,6 +125,10 @@ private void freeMemory() {
     // TODO
   }
 
+  private void deleteSpills() {
+    // TODO
+  }
+
   private SpillInfo[] insertRecordsIntoSorter(
       scala.collection.Iterator<Product2<K, V>> records) throws Exception {
     final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
@@ -154,55 +167,127 @@ private SpillInfo[] insertRecordsIntoSorter(
 
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];
+      } else if (spills.length == 1) {
+        // Note: we'll have to watch out for corner-cases in this code path when working on shuffle
+        // metrics integration, since any metrics updates that are performed during the merge will
+        // also have to be done here. In this branch, the shuffle technically didn't need to spill
+        // because we're only trying to merge one file, so we may need to ensure that metrics that
+        // would otherwise be counted as spill metrics are actually counted as regular write
+        // metrics.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        // Need to merge multiple spills.
+        if (transferToEnabled) {
+          return mergeSpillsWithTransferTo(spills, outputFile);
+        } else {
+          return mergeSpillsWithFileStream(spills, outputFile);
+        }
+      }
+    } catch (IOException e) {
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
+  private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
+    final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
+    FileOutputStream mergedFileOutputStream = null;
+
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      mergedFileOutputStream = new FileOutputStream(outputFile);
 
-    if (spills.length == 0) {
-      new FileOutputStream(outputFile).close();
-      return partitionLengths;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          final FileInputStream spillInputStream = spillInputStreams[i];
+          ByteStreams.copy(
+            new LimitedInputStream(spillInputStream, partitionLengthInSpill),
+            mergedFileOutputStream);
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
+        if (spillInputStreams[i] != null) {
+          spillInputStreams[i].close();
+        }
+      }
+      if (mergedFileOutputStream != null) {
+        mergedFileOutputStream.close();
+      }
     }
+    return partitionLengths;
+  }
 
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;
 
-    // TODO: We need to add an option to bypass transferTo here since older Linux kernels are
-    // affected by a bug here that can lead to data truncation; see the comments Utils.scala,
-    // in the copyStream() method. I didn't use copyStream() here because we only want to copy
-    // a limited number of bytes from the stream and I didn't want to modify / extend that method
-    // to accept a length.
-
-    // TODO: special case optimization for case where we only write one file (non-spill case).
-
-    for (int i = 0; i < spills.length; i++) {
-      spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
-    }
-
-    final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel();
-
-    for (int partition = 0; partition < numPartitions; partition++) {
+    try {
       for (int i = 0; i < spills.length; i++) {
-        final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-        long bytesToTransfer = partitionLengthInSpill;
-        final FileChannel spillInputChannel = spillInputChannels[i];
-        while (bytesToTransfer > 0) {
-          final long actualBytesTransferred = spillInputChannel.transferTo(
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug
+      // that affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
+          while (bytesToTransfer > 0) {
+            final long actualBytesTransferred = spillInputChannel.transferTo(
             spillInputChannelPositions[i], bytesToTransfer, mergedFileOutputChannel);
-          spillInputChannelPositions[i] += actualBytesTransferred;
-          bytesToTransfer -= actualBytesTransferred;
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
         }
-        partitionLengths[partition] += partitionLengthInSpill;
+      }
+      // Check the position after the transferTo loop to see if it is in the right position and
+      // raise an exception if it is incorrect. The position will not be increased to the expected
+      // length after calling transferTo in kernel version 2.6.32. This issue is described at
+      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+            "to disable this NIO feature."
+        );
+      }
+    } finally {
+      for (int i = 0; i < spills.length; i++) {
+        assert(spillInputChannelPositions[i] == spills[i].file.length());
+        if (spillInputChannels[i] != null) {
+          spillInputChannels[i].close();
+        }
+      }
+      if (mergedFileOutputChannel != null) {
+        mergedFileOutputChannel.close();
+      }
     }
-
-    // TODO: should this be in a finally block?
-    for (int i = 0; i < spills.length; i++) {
-      assert(spillInputChannelPositions[i] == spills[i].file.length());
-      spillInputChannels[i].close();
-    }
-    mergedFileOutputChannel.close();
-
     return partitionLengths;
   }
 
@@ -215,6 +300,9 @@ public Option<MapStatus> stop(boolean success) {
         stopping = true;
         freeMemory();
         if (success) {
+          if (mapStatus == null) {
+            throw new IllegalStateException("Cannot call stop(true) without having called write()");
+          }
           return Option.apply(mapStatus);
         } else {
           // The map task failed, so delete our output data.
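The transferTo() merge path above depends on two easily missed details: the destination channel must be opened in append mode, and the channel position has to be verified after the copy loop, because affected kernels (the 2.6.32 issue referenced above, JDK-7052359 / SPARK-3948) can fail to advance it and silently truncate the output. The following standalone sketch shows the same pattern outside of Spark; the TransferToSketch class and its append() method are invented for illustration and are not part of this patch.

    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.channels.FileChannel;

    // Illustrative helper (not part of the patch): appends `source` to `dest` with transferTo()
    // and verifies that the destination channel's position advanced by the expected amount.
    final class TransferToSketch {
      static void append(File source, File dest) throws IOException {
        try (FileChannel in = new FileInputStream(source).getChannel();
             // Append mode works around the kernel bug discussed above (SPARK-3948).
             FileChannel out = new FileOutputStream(dest, true).getChannel()) {
          final long expected = out.position() + in.size();
          long position = 0;
          long remaining = in.size();
          while (remaining > 0) {
            // transferTo() may copy fewer bytes than requested, so loop until everything is written.
            final long transferred = in.transferTo(position, remaining, out);
            position += transferred;
            remaining -= transferred;
          }
          // On affected kernels the output position may not advance; fail loudly here instead of
          // producing a truncated merged file.
          if (out.position() != expected) {
            throw new IOException("transferTo() did not advance the channel position as expected");
          }
        }
      }
    }

When this check trips, the remedy is the one the error message suggests: set spark.file.transferTo=false so that the stream-based mergeSpillsWithFileStream() path is used instead of the NIO path.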
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 55c447327ef35..b2eb68ce9dfea 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -18,21 +18,25 @@ package org.apache.spark.shuffle.unsafe;
 
 import java.io.File;
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.UUID;
+import java.util.*;
 
 import scala.*;
 import scala.runtime.AbstractFunction1;
 
+import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import static org.mockito.AdditionalAnswers.returnsFirstArg;
 import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.*;
@@ -52,18 +56,21 @@
 
 public class UnsafeShuffleWriterSuite {
 
+  static final int NUM_PARTITIONS = 4;
   final TaskMemoryManager memoryManager =
     new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-  // Compute key prefixes based on the records' partition ids
-  final HashPartitioner hashPartitioner = new HashPartitioner(4);
-
-  ShuffleMemoryManager shuffleMemoryManager;
-  BlockManager blockManager;
-  IndexShuffleBlockManager shuffleBlockManager;
-  DiskBlockManager diskBlockManager;
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITIONS);
+  File mergedOutputFile;
   File tempDir;
-  TaskContext taskContext;
-  SparkConf sparkConf;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockManager shuffleBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
 
   private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
     @Override
@@ -72,26 +79,23 @@ public OutputStream apply(OutputStream stream) {
     }
   }
 
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+  }
+
   @Before
-  public void setUp() {
-    shuffleMemoryManager = mock(ShuffleMemoryManager.class);
-    diskBlockManager = mock(DiskBlockManager.class);
-    blockManager = mock(BlockManager.class);
-    shuffleBlockManager = mock(IndexShuffleBlockManager.class);
-    tempDir = new File(Utils.createTempDir$default$1());
-    taskContext = mock(TaskContext.class);
-    sparkConf = new SparkConf();
-    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
-    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
-      @Override
-      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
-        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
-        File file = File.createTempFile("spillFile", ".spill", tempDir);
-        return Tuple2$.MODULE$.apply(blockId, file);
-      }
-    });
     when(blockManager.getDiskWriter(
       any(BlockId.class),
       any(File.class),
@@ -115,64 +119,103 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th
     });
     when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
       .then(returnsSecondArg());
-  }
-
-  @Test
-  @SuppressWarnings("unchecked")
-  public void basicShuffleWriting() throws Exception {
-    final ShuffleDependency<Object, Object, Object> dep = mock(ShuffleDependency.class);
-    when(dep.serializer()).thenReturn(Option.apply(new KryoSerializer(sparkConf)));
-    when(dep.partitioner()).thenReturn(hashPartitioner);
-
-    final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
-    final long[] partitionSizes = new long[hashPartitioner.numPartitions()];
     doAnswer(new Answer<Void>() {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
-        long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2];
-        System.arraycopy(
-          receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length);
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
         return null;
       }
     }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class));
 
-    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<Object, Object>(
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempLocalBlockId, File>>() {
+        @Override
+        public Tuple2<TempLocalBlockId, File> answer(
+            InvocationOnMock invocationOnMock) throws Throwable {
+          TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+
+    when(shuffleDep.serializer()).thenReturn(
+      Option.apply(new KryoSerializer(new SparkConf())));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+    SparkConf conf = new SparkConf();
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
       blockManager,
       shuffleBlockManager,
       memoryManager,
       shuffleMemoryManager,
-      new UnsafeShuffleHandle<Object, Object>(0, 1, dep),
+      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
       0, // map id
       taskContext,
-      sparkConf
+      conf
     );
+  }
 
-    final ArrayList<Product2<Object, Object>> numbersToSort =
-      new ArrayList<Product2<Object, Object>>();
-    numbersToSort.add(new Tuple2<Object, Object>(5, 5));
-    numbersToSort.add(new Tuple2<Object, Object>(1, 1));
-    numbersToSort.add(new Tuple2<Object, Object>(3, 3));
-    numbersToSort.add(new Tuple2<Object, Object>(2, 2));
-    numbersToSort.add(new Tuple2<Object, Object>(4, 4));
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      Assert.assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() {
+    createWriter(false).stop(false);
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    Assert.assertTrue(mapStatus.isDefined());
+    Assert.assertTrue(mergedOutputFile.exists());
+    Assert.assertArrayEquals(new long[NUM_PARTITIONS], partitionSizesInMergedFile);
+  }
 
-    writer.write(numbersToSort.iterator());
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITIONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
+    Assert.assertTrue(mergedOutputFile.exists());
 
     long sumOfPartitionSizes = 0;
-    for (long size: partitionSizes) {
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      Assert.assertEquals(partitionSizesInMergedFile[0], size);
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-    // TODO: actually try to read the shuffle output?
-
-    // TODO: add a test that manually triggers spills in order to exercise the merging.
-
-    // TODO: test that the temporary spill files were cleaned up after the merge.
+    assertSpillFilesWereCleanedUp();
   }
 
+  // TODO: actually try to read the shuffle output?
+  // TODO: add a test that manually triggers spills in order to exercise the merging.
+
 }
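On the test side, the refactoring above replaces a series of hand-written mock(...) calls with field-level @Mock(answer = RETURNS_SMART_NULLS) annotations that MockitoAnnotations.initMocks(this) wires up in setUp(). The sketch below shows that pattern in miniature, assuming a hypothetical NameService/Greeter pair that is not part of this patch; RETURNS_SMART_NULLS makes any unstubbed call fail with a descriptive error instead of returning a plain null.

    import static org.mockito.Answers.RETURNS_SMART_NULLS;
    import static org.mockito.Mockito.when;

    import org.junit.Assert;
    import org.junit.Before;
    import org.junit.Test;
    import org.mockito.Mock;
    import org.mockito.MockitoAnnotations;

    // Illustrative only: the types below are invented for this example.
    public class MockitoAnnotationSketch {

      interface NameService {
        String lookup(int id);
      }

      static class Greeter {
        private final NameService names;
        Greeter(NameService names) { this.names = names; }
        String greet(int id) { return "Hello, " + names.lookup(id); }
      }

      // RETURNS_SMART_NULLS is the same default answer the refactored suite uses for its mocks.
      @Mock(answer = RETURNS_SMART_NULLS) NameService nameService;

      @Before
      public void setUp() {
        // One annotation-driven initialization replaces several mock(...) calls.
        MockitoAnnotations.initMocks(this);
        when(nameService.lookup(42)).thenReturn("Josh");
      }

      @Test
      public void greetsByName() {
        Assert.assertEquals("Hello, Josh", new Greeter(nameService).greet(42));
      }
    }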