From 7cd013be50add13d076c372534beaf2ff7aa3f31 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 8 May 2015 17:26:39 -0700
Subject: [PATCH] Begin refactoring to enable proper tests for spilling.

---
 .../unsafe/UnsafeShuffleExternalSorter.java   |  6 +-
 .../shuffle/unsafe/UnsafeShuffleWriter.java   | 92 ++++++++++++-------
 .../unsafe/UnsafeShuffleWriterSuite.java      | 41 ++++++++-
 3 files changed, 102 insertions(+), 37 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 70c911252fddb..8f7c3b4232691 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -21,9 +21,9 @@
 import java.io.IOException;
 import java.util.LinkedList;
 
-import org.apache.spark.storage.*;
 import scala.Tuple2;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -32,6 +32,7 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -215,7 +216,8 @@ private SpillInfo writeSpillFile() throws IOException {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  private void spill() throws IOException {
+  @VisibleForTesting
+  void spill() throws IOException {
     final long threadId = Thread.currentThread().getId();
     logger.info("Thread " + threadId + " spilling sort data of " +
       org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" +
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 206812f8352d2..e5a942498ae00 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -32,6 +32,7 @@
 import scala.reflect.ClassTag$;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Files;
 import org.slf4j.Logger;
@@ -73,6 +74,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
 
   private MapStatus mapStatus = null;
+  private UnsafeShuffleExternalSorter sorter = null;
+  private byte[] serArray = null;
+  private ByteBuffer serByteBuffer;
+  // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
+  private SerializationStream serOutputStream;
 
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -113,25 +119,18 @@ public void write(Iterator<Product2<K, V>> records) {
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) {
     try {
-      final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records));
-      shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
     } catch (Exception e) {
       PlatformDependent.throwException(e);
     }
   }
 
-  private void freeMemory() {
-    // TODO
-  }
-
-  private void deleteSpills() {
-    // TODO
-  }
-
-  private SpillInfo[] insertRecordsIntoSorter(
-      scala.collection.Iterator<Product2<K, V>> records) throws Exception {
-    final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
+  private void open() throws IOException {
+    assert (sorter == null);
+    sorter = new UnsafeShuffleExternalSorter(
       memoryManager,
       shuffleMemoryManager,
       blockManager,
@@ -139,30 +138,53 @@ private SpillInfo[] insertRecordsIntoSorter(
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
-
-    final byte[] serArray = new byte[SER_BUFFER_SIZE];
-    final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray);
+    serArray = new byte[SER_BUFFER_SIZE];
+    serByteBuffer = ByteBuffer.wrap(serArray);
     // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
-    final SerializationStream serOutputStream =
-      serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+    serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+  }
 
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      final int partitionId = partitioner.getPartition(key);
-      serByteBuffer.position(0);
-      serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
-      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
-      serOutputStream.flush();
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
+    if (sorter == null) {
+      open();
+    }
+    serArray = null;
+    serByteBuffer = null;
+    serOutputStream = null;
+    final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills());
+    sorter = null;
+    shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
 
-      final int serializedRecordSize = serByteBuffer.position();
-      assert (serializedRecordSize > 0);
+  private void freeMemory() {
+    // TODO
+  }
 
-      sorter.insertRecord(
-        serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    if (sorter == null) {
+      open();
     }
+    final K key = record._1();
+    final int partitionId = partitioner.getPartition(key);
+    serByteBuffer.position(0);
+    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();
 
-    return sorter.closeAndGetSpills();
+    final int serializedRecordSize = serByteBuffer.position();
+    assert (serializedRecordSize > 0);
+
+    sorter.insertRecord(
+      serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  }
+
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);
+    sorter.spill();
   }
 
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
@@ -222,6 +244,9 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th
       for (int i = 0; i < spills.length; i++) {
         if (spillInputStreams[i] != null) {
           spillInputStreams[i].close();
+          if (!spills[i].file.delete()) {
+            logger.error("Error while deleting spill file {}", spills[i]);
+          }
         }
       }
       if (mergedFileOutputStream != null) {
@@ -282,6 +307,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
         assert(spillInputChannelPositions[i] == spills[i].file.length());
         if (spillInputChannels[i] != null) {
           spillInputChannels[i].close();
+          if (!spills[i].file.delete()) {
+            logger.error("Error while deleting spill file {}", spills[i]);
+          }
         }
       }
       if (mergedFileOutputChannel != null) {
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index b2eb68ce9dfea..09eb537c04367 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -193,13 +193,13 @@ public void writeEmptyIterator() throws Exception {
   @Test
   public void writeWithoutSpilling() throws Exception {
     // In this example, each partition should have exactly one record:
-    final ArrayList<Product2<Object, Object>> datatToWrite =
+    final ArrayList<Product2<Object, Object>> dataToWrite =
       new ArrayList<Product2<Object, Object>>();
     for (int i = 0; i < NUM_PARTITITONS; i++) {
-      datatToWrite.add(new Tuple2<Object, Object>(i, i));
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(datatToWrite.iterator());
+    writer.write(dataToWrite.iterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
     Assert.assertTrue(mergedOutputFile.exists());
@@ -215,7 +215,42 @@ public void writeWithoutSpilling() throws Exception {
     assertSpillFilesWereCleanedUp();
   }
 
+  private void testMergingSpills(boolean transferToEnabled) throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    Assert.assertTrue(mapStatus.isDefined());
+    Assert.assertTrue(mergedOutputFile.exists());
+    Assert.assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void mergeSpillsWithTransferTo() throws Exception {
+    testMergingSpills(true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStream() throws Exception {
+    testMergingSpills(false);
+  }
+
   // TODO: actually try to read the shuffle output?
   // TODO: add a test that manually triggers spills in order to exercise the merging.
 
 }
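
Usage sketch (not part of the patch): the @VisibleForTesting hooks added above let a test exercise spilling deterministically instead of waiting for memory pressure. A minimal outline, assuming the suite helpers shown in the diff (createWriter, mergedOutputFile, spillFilesCreated) are available:

    // Sketch only: insert a few records, force one spill, then merge the spill
    // files and the in-memory data into the final map output.
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); // buffered in the sorter
    writer.forceSorterToSpill();                                     // produces the first spill file
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); // lands in a second spill at close
    writer.closeAndWriteOutput();                                    // merges spills, writes data + index files
    Assert.assertTrue(writer.stop(true).isDefined());                // a MapStatus is reported to the driver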