diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
index 5435c2c98428f..5d13354231491 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.unsafe;
 
-import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.TempShuffleBlockId;
 
 import java.io.File;
 
@@ -27,9 +27,9 @@
 final class SpillInfo {
   final long[] partitionLengths;
   final File file;
-  final BlockId blockId;
+  final TempShuffleBlockId blockId;
 
-  public SpillInfo(int numPartitions, File file, BlockId blockId) {
+  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
     this.partitionLengths = new long[numPartitions];
     this.file = file;
     this.blockId = blockId;
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index ccc1018a71168..3cf99307c47cc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -153,7 +153,7 @@ private SpillInfo writeSpillFile() throws IOException {
     final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
       blockManager.diskBlockManager().createTempShuffleBlock();
     final File file = spilledFileInfo._2();
-    final BlockId blockId = spilledFileInfo._1();
+    final TempShuffleBlockId blockId = spilledFileInfo._1();
     final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
 
     // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
@@ -320,7 +320,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
       }
     }
     if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.debug("Required space {} is less than free space in current page ({}}", requiredSpace,
+      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
        freeSpaceInCurrentPage);
      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
      // without using the free space at the end of the current page. We should also do this for
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 5bf04617854bb..df05a95506f4b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -20,6 +20,7 @@
 import java.io.*;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;
+import javax.annotation.Nullable;
 
 import scala.Option;
 import scala.Product2;
@@ -35,6 +36,9 @@
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -53,8 +57,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
 
-  @VisibleForTesting
-  static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
   private final BlockManager blockManager;
@@ -201,6 +203,12 @@ void forceSorterToSpill() throws IOException {
 
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+    final boolean fastMergeEnabled =
+      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+    final boolean fastMergeIsSupported =
+      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -215,11 +223,20 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
         Files.move(spills[0].file, outputFile);
         return spills[0].partitionLengths;
       } else {
-        // Need to merge multiple spills.
-        if (transferToEnabled) {
-          return mergeSpillsWithTransferTo(spills, outputFile);
+        if (fastMergeEnabled && fastMergeIsSupported) {
+          // Compression is disabled or we are using an IO compression codec that supports
+          // decompression of concatenated compressed streams, so we can perform a fast spill merge
+          // that doesn't need to interpret the spilled bytes.
+          if (transferToEnabled) {
+            logger.debug("Using transferTo-based fast merge");
+            return mergeSpillsWithTransferTo(spills, outputFile);
+          } else {
+            logger.debug("Using fileStream-based fast merge");
+            return mergeSpillsWithFileStream(spills, outputFile, null);
+          }
         } else {
-          return mergeSpillsWithFileStream(spills, outputFile);
+          logger.debug("Using slow merge");
+          return mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
         }
       }
     } catch (IOException e) {
@@ -230,27 +247,40 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     }
   }
 
-  private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
+  private long[] mergeSpillsWithFileStream(
+      SpillInfo[] spills,
+      File outputFile,
+      @Nullable CompressionCodec compressionCodec) throws IOException {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
-    final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
-    FileOutputStream mergedFileOutputStream = null;
+    final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+    OutputStream mergedFileOutputStream = null;
 
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
-      mergedFileOutputStream = new FileOutputStream(outputFile);
-
       for (int partition = 0; partition < numPartitions; partition++) {
+        final long initialFileLength = outputFile.length();
+        mergedFileOutputStream = new FileOutputStream(outputFile, true);
+        if (compressionCodec != null) {
+          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+        }
+
        for (int i = 0; i < spills.length; i++) {
          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-          final FileInputStream spillInputStream = spillInputStreams[i];
-          ByteStreams.copy
-            (new LimitedInputStream(spillInputStream, partitionLengthInSpill),
-              mergedFileOutputStream);
-          partitionLengths[partition] += partitionLengthInSpill;
+          if (partitionLengthInSpill > 0) {
+            InputStream partitionInputStream =
+              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            }
+            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+          }
        }
+        mergedFileOutputStream.flush();
+        mergedFileOutputStream.close();
+        partitionLengths[partition] = (outputFile.length() - initialFileLength);
      }
    } finally {
      for (InputStream stream : spillInputStreams) {
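Note (illustration, not part of the patch): the comment in mergeSpills() leans on a property of the LZF format, namely that streams compressed independently can be concatenated and then read back through a single decompressing stream. A minimal sketch of that property using only the codec APIs that appear in this patch; the class name, in-memory buffers, and the ByteStreams.toByteArray call are invented for the example.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.OutputStream;

    import com.google.common.io.ByteStreams;

    import org.apache.spark.SparkConf;
    import org.apache.spark.io.CompressionCodec;
    import org.apache.spark.io.CompressionCodec$;
    import org.apache.spark.io.LZFCompressionCodec;

    // Illustrative only: shows why the fast merge can concatenate raw spill bytes
    // when the LZF codec (or no compression at all) is in use.
    public class ConcatenatedStreamsSketch {
      public static void main(String[] args) throws IOException {
        SparkConf conf = new SparkConf()
          .set("spark.io.compression.codec", LZFCompressionCodec.class.getName());
        CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(conf);

        // Compress two "spill regions" independently and append them to one buffer,
        // mimicking how per-partition regions from different spill files end up
        // laid out back to back in the merged output file.
        ByteArrayOutputStream merged = new ByteArrayOutputStream();
        for (String region : new String[] { "records from spill 0; ", "records from spill 1" }) {
          OutputStream out = codec.compressedOutputStream(merged);
          out.write(region.getBytes("UTF-8"));
          out.close();  // flushes this region; ByteArrayOutputStream.close() is a no-op
        }

        // A single decompressing stream over the concatenation yields both regions in order,
        // so the merge never has to decompress and recompress the spilled bytes.
        InputStream in = codec.compressedInputStream(new ByteArrayInputStream(merged.toByteArray()));
        System.out.println(new String(ByteStreams.toByteArray(in), "UTF-8"));
      }
    }

This is the property that the fastMergeIsSupported check guards: only LZFCompressionCodec, or compression disabled entirely, takes the fast path (and only when spark.shuffle.unsafe.fastMergeEnabled is true); other codecs fall back to the slow merge, which decompresses each partition and re-encodes it into the merged file.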
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 48ba85f917b87..511fdfa43d543 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -37,11 +37,14 @@
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.network.util.LimitedInputStream;
@@ -65,6 +68,7 @@ public class UnsafeShuffleWriterSuite {
   File tempDir;
   long[] partitionSizesInMergedFile;
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
 
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@@ -74,10 +78,14 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
 
-  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
     @Override
     public OutputStream apply(OutputStream stream) {
-      return stream;
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
     }
   }
 
@@ -98,6 +106,7 @@ public void setUp() throws IOException {
     mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
     partitionSizesInMergedFile = null;
     spillFilesCreated.clear();
+    conf = new SparkConf();
 
     when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
 
@@ -123,8 +132,35 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
         );
       }
     });
-    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
-      .then(returnsSecondArg());
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
 
     when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
     doAnswer(new Answer<Void>() {
@@ -136,11 +172,11 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
     }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
-      new Answer<Tuple2<TempLocalBlockId, File>>() {
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
        @Override
-        public Tuple2<TempLocalBlockId, File> answer(
+        public Tuple2<TempShuffleBlockId, File> answer(
            InvocationOnMock invocationOnMock) throws Throwable {
-          TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
          File file = File.createTempFile("spillFile", ".spill", tempDir);
          spillFilesCreated.add(file);
          return Tuple2$.MODULE$.apply(blockId, file);
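Note (illustration, not part of the patch): the mocked wrapForCompression answers assert that the writer only ever hands them TempShuffleBlockId instances, matching the type change in SpillInfo above. In the real BlockManager the decision to wrap a stream is made per block-ID type; the Java sketch below is only a rough approximation of that dispatch, and the mapping of block types to config keys is an assumption rather than a quote of BlockManager.scala.

    import org.apache.spark.SparkConf;
    import org.apache.spark.storage.BlockId;
    import org.apache.spark.storage.TempLocalBlockId;
    import org.apache.spark.storage.TempShuffleBlockId;

    // Rough, hypothetical sketch of per-block-type compression dispatch.
    final class ShouldCompressSketch {
      static boolean shouldCompress(BlockId blockId, SparkConf conf) {
        if (blockId instanceof TempShuffleBlockId) {
          // Spill files written by the unsafe shuffle writer are shuffle output,
          // so they would follow spark.shuffle.compress.
          return conf.getBoolean("spark.shuffle.compress", true);
        } else if (blockId instanceof TempLocalBlockId) {
          // Other temporary spill blocks would follow spark.shuffle.spill.compress instead.
          return conf.getBoolean("spark.shuffle.spill.compress", true);
        } else {
          return false; // remaining block types elided from this sketch
        }
      }
    }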
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter( blockManager, @@ -164,7 +199,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl new UnsafeShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, - new SparkConf() + conf ); } @@ -183,8 +218,11 @@ private List> readRecordsFromFile() throws IOException { if (partitionSize > 0) { InputStream in = new FileInputStream(mergedOutputFile); ByteStreams.skipFully(in, startOffset); - DeserializationStream recordsStream = serializer.newInstance().deserializeStream( - new LimitedInputStream(in, partitionSize)); + in = new LimitedInputStream(in, partitionSize); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); Iterator> records = recordsStream.asKeyValueIterator(); while (records.hasNext()) { Tuple2 record = records.next(); @@ -245,7 +283,15 @@ public void writeWithoutSpilling() throws Exception { assertSpillFilesWereCleanedUp(); } - private void testMergingSpills(boolean transferToEnabled) throws IOException { + private void testMergingSpills( + boolean transferToEnabled, + String compressionCodecName) throws IOException { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList>(); @@ -265,11 +311,13 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException { Assert.assertTrue(mergedOutputFile.exists()); Assert.assertEquals(2, spillFilesCreated.size()); - long sumOfPartitionSizes = 0; - for (long size: partitionSizesInMergedFile) { - sumOfPartitionSizes += size; - } - Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + // This assertion only holds for the fast merging path: + // long sumOfPartitionSizes = 0; + // for (long size: partitionSizesInMergedFile) { + // sumOfPartitionSizes += size; + // } + // Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + Assert.assertTrue(mergedOutputFile.length() > 0); Assert.assertEquals( HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); @@ -277,13 +325,43 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException { } @Test - public void mergeSpillsWithTransferTo() throws Exception { - testMergingSpills(true); + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName()); + } + + @Test + public void 
@@ -277,13 +325,43 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
   }
 
   @Test
-  public void mergeSpillsWithTransferTo() throws Exception {
-    testMergingSpills(true);
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
   }
 
   @Test
-  public void mergeSpillsWithFileStream() throws Exception {
-    testMergingSpills(false);
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
   }
 
   @Test
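Note (illustration, not part of the patch): taken together, the renamed and added tests cover every combination of the settings consulted in mergeSpills(). A hedged recap of that decision logic, restated as a standalone method; the class name, method name, and return strings are invented for this sketch, while the config keys and defaults are the ones read by the patched code.

    import org.apache.spark.SparkConf;
    import org.apache.spark.io.CompressionCodec;
    import org.apache.spark.io.CompressionCodec$;
    import org.apache.spark.io.LZ4CompressionCodec;
    import org.apache.spark.io.LZFCompressionCodec;

    // Illustrative restatement of the branch taken in UnsafeShuffleWriter.mergeSpills().
    final class MergeStrategySketch {
      static String chooseMergeStrategy(SparkConf conf) {
        boolean compressionEnabled = conf.getBoolean("spark.shuffle.compress", true);
        CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(conf);
        boolean fastMergeEnabled = conf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
        boolean fastMergeIsSupported = !compressionEnabled || codec instanceof LZFCompressionCodec;
        boolean transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
        if (fastMergeEnabled && fastMergeIsSupported) {
          return transferToEnabled
            ? "transferTo-based fast merge"
            : "fileStream-based fast merge (no codec)";
        } else {
          return "slow merge (decompress each partition and re-encode through the codec)";
        }
      }

      public static void main(String[] args) {
        SparkConf conf = new SparkConf()
          .set("spark.shuffle.compress", "true")
          .set("spark.io.compression.codec", LZ4CompressionCodec.class.getName());
        // LZ4 (like Snappy) forces the slow merge path in this sketch, mirroring the tests above.
        System.out.println(chooseMergeStrategy(conf));
      }
    }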