From 4f70141aa6949d9251719790551e184cac66a05b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 May 2015 11:36:01 -0700 Subject: [PATCH] Fix merging; now passes UnsafeShuffleSuite tests. --- .../unsafe/UnsafeShuffleSpillWriter.java | 7 +++-- .../shuffle/unsafe/UnsafeShuffleWriter.java | 26 ++++++++++++++----- .../unsafe/UnsafeShuffleWriterSuite.java | 7 ++++- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java index 8e0a21ec6b3a5..fd2c170bd2e41 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java @@ -118,7 +118,6 @@ private SpillInfo writeSpillFile() throws IOException { final File file = spilledFileInfo._2(); final BlockId blockId = spilledFileInfo._1(); final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); - spills.add(spillInfo); final SerializerInstance ser = new DummySerializerInstance(); writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics); @@ -154,7 +153,11 @@ private SpillInfo writeSpillFile() throws IOException { if (writer != null) { writer.commitAndClose(); - spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + // TODO: comment and explain why our handling of empty spills, etc. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } } return spillInfo; } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 47fe214634abb..ad842502bf24f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -157,7 +157,14 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; + + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); + return partitionLengths; + } + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; // TODO: We need to add an option to bypass transferTo here since older Linux kernels are // affected by a bug here that can lead to data truncation; see the comments Utils.scala, @@ -173,24 +180,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { final FileChannel mergedFileOutputChannel = new FileOutputStream(outputFile).getChannel(); - for (int partition = 0; partition < numPartitions; partition++ ) { + for (int partition = 0; partition < numPartitions; partition++) { for (int i = 0; i < spills.length; i++) { - final long bytesToTransfer = spills[i].partitionLengths[partition]; - long bytesRemainingToBeTransferred = bytesToTransfer; + System.out.println("In partition " + partition + " and spill " + i ); + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + System.out.println("Partition length in spill is " + partitionLengthInSpill); + System.out.println("input channel position is " + spillInputChannels[i].position()); + long bytesRemainingToBeTransferred = partitionLengthInSpill; final FileChannel spillInputChannel = spillInputChannels[i]; - long fromPosition = spillInputChannel.position(); while (bytesRemainingToBeTransferred > 0) { - bytesRemainingToBeTransferred -= spillInputChannel.transferTo( - fromPosition, + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], bytesRemainingToBeTransferred, mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesRemainingToBeTransferred -= actualBytesTransferred; } - partitionLengths[partition] += bytesToTransfer; + partitionLengths[partition] += partitionLengthInSpill; } } // TODO: should this be in a finally block? for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); spillInputChannels[i].close(); } mergedFileOutputChannel.close(); diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 8ba548420bd4b..9008cc2de9bd5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -158,7 +158,8 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { writer.write(numbersToSort.iterator()); - final MapStatus mapStatus = writer.stop(true).get(); + final Option mapStatus = writer.stop(true); + Assert.assertTrue(mapStatus.isDefined()); long sumOfPartitionSizes = 0; for (long size: partitionSizes) { @@ -166,6 +167,10 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + // TODO: actually try to read the shuffle output? + + // TODO: add a test that manually triggers spills in order to exercise the merging. + // TODO: test that the temporary spill files were cleaned up after the merge. }