From 2776acaf368616f96f9c42f79744293a70b7a08a Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 3 May 2015 21:20:47 -0700
Subject: [PATCH] First passing test for ExternalSorter.

---
 .../sort/UnsafeExternalSortSpillMerger.java   |  10 +-
 .../unsafe/sort/UnsafeExternalSorter.java     |   5 +-
 .../spark/unsafe/sort/UnsafeSorter.java       |  10 +-
 .../unsafe/sort/UnsafeSorterSpillReader.java  |  25 +-
 .../unsafe/sort/UnsafeSorterSpillWriter.java  |  54 ++++-
 .../spark/storage/BlockObjectWriter.scala     |   8 +-
 .../sort/UnsafeExternalSorterSuite.java       | 218 ++++++++++--------
 pom.xml                                       |   2 +-
 8 files changed, 209 insertions(+), 123 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
index 89928ffaa448d..c6bd4ee9df4ff 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
@@ -30,7 +30,7 @@ public final class UnsafeExternalSortSpillMerger {
   public static abstract class MergeableIterator {
     public abstract boolean hasNext();
 
-    public abstract void advanceRecord();
+    public abstract void loadNextRecord();
 
     public abstract long getPrefix();
 
@@ -68,6 +68,9 @@ public int compare(MergeableIterator left, MergeableIterator right) {
   }
 
   public void addSpill(MergeableIterator spillReader) {
+    if (spillReader.hasNext()) {
+      spillReader.loadNextRecord();
+    }
     priorityQueue.add(spillReader);
   }
 
@@ -79,17 +82,18 @@ public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {
 
       @Override
       public boolean hasNext() {
-        return spillReader.hasNext() || !priorityQueue.isEmpty();
+        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
       }
 
       @Override
       public RecordAddressAndKeyPrefix next() {
         if (spillReader != null) {
           if (spillReader.hasNext()) {
+            spillReader.loadNextRecord();
             priorityQueue.add(spillReader);
           }
         }
-        spillReader = priorityQueue.poll();
+        spillReader = priorityQueue.remove();
         record.baseObject = spillReader.getBaseObject();
         record.baseOffset = spillReader.getBaseOffset();
         record.keyPrefix = spillReader.getPrefix();
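The queue invariant above is worth spelling out: every `MergeableIterator` sitting in the priority queue must already have a current record loaded, so the comparator can order the queue by `getPrefix()`. That is why `addSpill` now primes each iterator before insertion, and why `remove()` replaces `poll()`, so an empty queue fails fast instead of handing back `null`. A minimal standalone sketch of the same k-way merge pattern, using plain `long` runs rather than Spark's API:

```java
import java.util.Comparator;
import java.util.PriorityQueue;

final class KWayMergeSketch {
  // Simplified stand-in for MergeableIterator: a sorted array with a cursor.
  static final class Run {
    final long[] prefixes;
    int pos = -1;
    Run(long... prefixes) { this.prefixes = prefixes; }
    boolean hasNext() { return pos + 1 < prefixes.length; }
    void loadNextRecord() { pos++; }
    long getPrefix() { return prefixes[pos]; }
  }

  public static void main(String[] args) {
    PriorityQueue<Run> queue = new PriorityQueue<Run>(2, new Comparator<Run>() {
      @Override
      public int compare(Run left, Run right) {
        return Long.compare(left.getPrefix(), right.getPrefix());
      }
    });
    for (Run run : new Run[] { new Run(1, 3, 5), new Run(2, 4) }) {
      if (run.hasNext()) {
        run.loadNextRecord();  // prime before insertion, as addSpill() now does
        queue.add(run);
      }
    }
    while (!queue.isEmpty()) {
      // remove() throws on an empty queue instead of returning null like poll()
      Run smallest = queue.remove();
      System.out.print(smallest.getPrefix() + " ");  // prints: 1 2 3 4 5
      if (smallest.hasNext()) {
        smallest.loadNextRecord();  // reload only while the run is out of the queue
        queue.add(smallest);
      }
    }
  }
}
```

Note that the reload happens strictly while the run is outside the queue; mutating an element's sort key while it sits inside a `PriorityQueue` would silently corrupt the heap, which is why the merger's `next()` re-inserts `spillReader` before taking the new head.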
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
index 613d07cf6a316..42669055bcf1c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
@@ -38,7 +38,6 @@ public final class UnsafeExternalSorter {
 
   private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
-  private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
 
   private final PrefixComparator prefixComparator;
   private final RecordComparator recordComparator;
@@ -92,6 +91,7 @@ private void openSorter() {
 
   public void spill() throws IOException {
     final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics);
+    spillWriters.add(spillWriter);
     final Iterator<RecordPointerAndKeyPrefix> sortedRecords = sorter.getSortedIterator();
     while (sortedRecords.hasNext()) {
       final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
@@ -110,8 +110,11 @@ private void freeMemory() {
     final Iterator<MemoryBlock> iter = allocatedPages.iterator();
     while (iter.hasNext()) {
       memoryManager.freePage(iter.next());
+      shuffleMemoryManager.release(PAGE_SIZE);
       iter.remove();
     }
+    currentPage = null;
+    currentPagePosition = -1;
   }
 
   private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
index 1801585e2ed84..0f844c1997668 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
@@ -169,7 +169,8 @@ public void remove() {
 
   public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() {
     sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
-    return new UnsafeExternalSortSpillMerger.MergeableIterator() {
+    UnsafeExternalSortSpillMerger.MergeableIterator iter =
+      new UnsafeExternalSortSpillMerger.MergeableIterator() {
 
       private int position = 0;
       private Object baseObject;
@@ -182,12 +183,12 @@ public boolean hasNext() {
       }
 
       @Override
-      public void advanceRecord() {
+      public void loadNextRecord() {
         final long recordPointer = sortBuffer[position];
-        baseObject = memoryManager.getPage(recordPointer);
-        baseOffset = memoryManager.getOffsetInPage(recordPointer);
         keyPrefix = sortBuffer[position + 1];
         position += 2;
+        baseObject = memoryManager.getPage(recordPointer);
+        baseOffset = memoryManager.getOffsetInPage(recordPointer);
       }
 
       @Override
@@ -205,5 +206,6 @@ public long getBaseOffset() {
         return baseOffset;
       }
     };
+    return iter;
   }
 }
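For context on what `sortBuffer` holds: record payloads live in pages and never move during the sort. The sorter keeps `[record pointer, key prefix]` pairs of longs, which is why the sort covers `sortBufferInsertPosition / 2` elements and `loadNextRecord()` advances `position += 2`. The sketch below illustrates that pair layout together with the kind of page-number/offset packing `TaskMemoryManager.encodePageNumberAndOffset` performs; the 13-bit/51-bit split is an assumption for illustration, not something this patch specifies:

```java
final class PointerPackingSketch {
  static final int OFFSET_BITS = 51;                     // assumed split
  static final long OFFSET_MASK = (1L << OFFSET_BITS) - 1;

  static long encode(int pageNumber, long offsetInPage) {
    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & OFFSET_MASK);
  }

  static int decodePageNumber(long pointer) { return (int) (pointer >>> OFFSET_BITS); }
  static long decodeOffset(long pointer) { return pointer & OFFSET_MASK; }

  public static void main(String[] args) {
    long[] sortBuffer = new long[8];
    int insertPosition = 0;
    // insertRecord(pointer, prefix) appends two longs per record:
    sortBuffer[insertPosition++] = encode(2, 128);  // record pointer
    sortBuffer[insertPosition++] = 42L;             // key prefix
    // getMergeableIterator() walks the pairs, two longs at a time:
    long pointer = sortBuffer[0];
    System.out.println(decodePageNumber(pointer) + " @ " + decodeOffset(pointer)
        + ", prefix " + sortBuffer[1]);  // prints: 2 @ 128, prefix 42
  }
}
```

Sorting only the pointer/prefix pairs is what makes the spill cheap: `spill()` can stream records out in sorted order by chasing each pointer back into its page.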
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
index e2d5e6a8faa10..7c696240aaa73 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
@@ -18,15 +18,9 @@
 package org.apache.spark.unsafe.sort;
 
 import com.google.common.io.ByteStreams;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.serializer.JavaSerializerInstance;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.BlockObjectWriter;
-import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
-import scala.Tuple2;
 
 import java.io.*;
 
@@ -39,6 +33,7 @@ public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger
   private long keyPrefix;
   private final byte[] arr = new byte[1024 * 1024];  // TODO: tune this (maybe grow dynamically)?
   private final Object baseObject = arr;
+  private int nextRecordLength;
   private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
@@ -46,11 +41,11 @@ public UnsafeSorterSpillReader(
       File file,
       BlockId blockId) throws IOException {
     this.file = file;
+    assert (file.length() > 0);
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
     this.in = blockManager.wrapForCompression(blockId, bs);
     this.din = new DataInputStream(this.in);
-    assert (file.length() > 0);
-    advanceRecord();
+    nextRecordLength = din.readInt();
   }
 
   @Override
@@ -59,21 +54,19 @@ public boolean hasNext() {
   }
 
   @Override
-  public void advanceRecord() {
+  public void loadNextRecord() {
     try {
-      final int recordLength = din.readInt();
-      if (recordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+      keyPrefix = din.readLong();
+      ByteStreams.readFully(in, arr, 0, nextRecordLength);
+      nextRecordLength = din.readInt();
+      if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
         in.close();
         in = null;
-        return;
+        din = null;
       }
-      keyPrefix = din.readLong();
-      ByteStreams.readFully(in, arr, 0, recordLength);
-
     } catch (Exception e) {
       PlatformDependent.throwException(e);
     }
-    throw new IllegalStateException();
   }
 
   @Override
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
index fdda38d3f1c47..e0649122ac09c 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -18,7 +18,9 @@
 package org.apache.spark.unsafe.sort;
 
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DeserializationStream;
 import org.apache.spark.serializer.JavaSerializerInstance;
+import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManager;
@@ -26,10 +28,10 @@
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
 import scala.Tuple2;
+import scala.reflect.ClassTag;
 
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.IOException;
+import java.io.*;
+import java.nio.ByteBuffer;
 
 public final class UnsafeSorterSpillWriter {
 
@@ -51,7 +53,47 @@ public UnsafeSorterSpillWriter(
     this.file = spilledFileInfo._2();
     this.blockId = spilledFileInfo._1();
     // Dummy serializer:
-    final SerializerInstance ser = new JavaSerializerInstance(0, false, null);
+    final SerializerInstance ser = new SerializerInstance() {
+      @Override
+      public SerializationStream serializeStream(OutputStream s) {
+        return new SerializationStream() {
+          @Override
+          public void flush() {
+
+          }
+
+          @Override
+          public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+            return null;
+          }
+
+          @Override
+          public void close() {
+
+          }
+        };
+      }
+
+      @Override
+      public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+        return null;
+      }
+
+      @Override
+      public DeserializationStream deserializeStream(InputStream s) {
+        return null;
+      }
+
+      @Override
+      public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+        return null;
+      }
+
+      @Override
+      public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+        return null;
+      }
+    };
     writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
     dos = new DataOutputStream(writer);
   }
@@ -61,14 +103,14 @@ public void write(
       long baseOffset,
       int recordLength,
       long keyPrefix) throws IOException {
+    dos.writeInt(recordLength);
+    dos.writeLong(keyPrefix);
     PlatformDependent.copyMemory(
       baseObject,
       baseOffset + 4,
       arr,
       PlatformDependent.BYTE_ARRAY_OFFSET,
       recordLength);
-    dos.writeInt(recordLength);
-    dos.writeLong(keyPrefix);
     writer.write(arr, 0, recordLength);
     // TODO: add a test that detects whether we leave this call out:
     writer.recordWritten();
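Taken together, the reader and writer changes pin down the spill-file framing: per record an `int` length, a `long` key prefix, then the raw bytes, terminated by `EOF_MARKER` in place of a length. The reader keeps one record length of look-ahead (`nextRecordLength`, first read in the constructor), which is what lets `hasNext()` stay a cheap check and removes the unreachable `throw new IllegalStateException()` from the old `advanceRecord()`. A self-contained sketch of the same framing; the `EOF_MARKER` value here is an assumption, the real constant lives in `UnsafeSorterSpillWriter`:

```java
import java.io.*;

final class SpillFramingSketch {
  static final int EOF_MARKER = -1;  // assumed sentinel value

  public static void main(String[] args) throws IOException {
    // Write side: [int length][long prefix][payload bytes] per record.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    for (int value : new int[] {1, 3, 5}) {
      byte[] payload = new byte[] {(byte) value, 0, 0, 0};  // 4-byte record
      out.writeInt(payload.length);  // length header
      out.writeLong(value);          // key prefix
      out.write(payload);            // record payload
    }
    out.writeInt(EOF_MARKER);
    out.close();

    // Read side: one length of look-ahead, exactly like loadNextRecord().
    DataInputStream in = new DataInputStream(
        new ByteArrayInputStream(bytes.toByteArray()));
    int nextRecordLength = in.readInt();  // constructor-time look-ahead
    while (nextRecordLength != EOF_MARKER) {
      long prefix = in.readLong();
      byte[] record = new byte[nextRecordLength];
      in.readFully(record);
      System.out.println("prefix " + prefix);
      nextRecordLength = in.readInt();    // look ahead to the next record
    }
  }
}
```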
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 499dd97c0656a..f273a31706cd8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -223,7 +223,13 @@ private[spark] class DiskBlockObjectWriter(
     }
   }
 
-  override def write(b: Int): Unit = throw new UnsupportedOperationException()
+  override def write(b: Int): Unit = {
+    if (!initialized) {
+      open()
+    }
+
+    bs.write(b)
+  }
 
   override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (!initialized) {
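This Scala change is forced by the Java side: `UnsafeSorterSpillWriter` wraps the `DiskBlockObjectWriter` in a `DataOutputStream`, and `DataOutputStream.writeInt` pushes its four bytes through the underlying stream's single-byte `write(int)`, which previously threw `UnsupportedOperationException`. A quick standalone demonstration of that contract:

```java
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStream;

final class SingleByteWriteSketch {
  public static void main(String[] args) throws IOException {
    final int[] singleByteWrites = {0};
    OutputStream counting = new OutputStream() {
      @Override
      public void write(int b) {  // the method DiskBlockObjectWriter used to reject
        singleByteWrites[0]++;
      }
    };
    DataOutputStream dos = new DataOutputStream(counting);
    dos.writeInt(42);  // four single-byte writes land on the wrapped stream
    System.out.println(singleByteWrites[0] + " single-byte writes");  // prints: 4
  }
}
```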
diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java
index e4376f1cea4fc..4f2aa9b895c01 100644
--- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,29 +19,117 @@
 
 import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
 
-import java.util.Arrays;
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.file.Files;
 import java.util.Iterator;
+import java.util.UUID;
 
-import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.*;
+import static org.mockito.AdditionalAnswers.*;
 
 public class UnsafeExternalSorterSuite {
 
-  private static String getStringFromDataPage(Object baseObject, long baseOffset) {
-    final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
-    final byte[] strBytes = new byte[strLength];
-    PlatformDependent.copyMemory(
-      baseObject,
-      baseOffset + 8,
-      strBytes,
-      PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
-    return new String(strBytes);
+  final TaskMemoryManager memoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  // Compute key prefixes based on the records' partition ids
+  final HashPartitioner hashPartitioner = new HashPartitioner(4);
+  // Use integer comparison for comparing prefixes (which are partition ids, in this case)
+  final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() {
+    @Override
+    public int compare(long prefix1, long prefix2) {
+      return (int) prefix1 - (int) prefix2;
+    }
+  };
+  // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
+  // use a dummy comparator
+  final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() {
+    @Override
+    public int compare(
+      Object leftBaseObject,
+      long leftBaseOffset,
+      Object rightBaseObject,
+      long rightBaseOffset) {
+      return 0;
+    }
+  };
+
+  ShuffleMemoryManager shuffleMemoryManager;
+  BlockManager blockManager;
+  DiskBlockManager diskBlockManager;
+  File tempDir;
+
+  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      return stream;
+    }
+  }
+
+  @Before
+  public void setUp() {
+    shuffleMemoryManager = mock(ShuffleMemoryManager.class);
+    diskBlockManager = mock(DiskBlockManager.class);
+    blockManager = mock(BlockManager.class);
+    tempDir = new File(Utils.createTempDir$default$1());
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+      @Override
+      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+        File file = File.createTempFile("spillFile", ".spill", tempDir);
+        return Tuple2$.MODULE$.apply(blockId, file);
+      }
+    });
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+      .then(returnsSecondArg());
+  }
+
+  private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
+    final int[] arr = new int[] { value };
+    sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
   }
 
   /**
@@ -49,88 +137,36 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset)
    */
   @Test
  public void testSortingOnlyByPartitionId() throws Exception {
-    final String[] dataToSort = new String[] {
-      "Boba",
-      "Pearls",
-      "Tapioca",
-      "Taho",
-      "Condensed Milk",
-      "Jasmine",
-      "Milk Tea",
-      "Lychee",
-      "Mango"
-    };
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
-    final Object baseObject = dataPage.getBaseObject();
-    // Write the records into the data page:
-    long position = dataPage.getBaseOffset();
-    for (String str : dataToSort) {
-      final byte[] strBytes = str.getBytes("utf-8");
-      PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
-      position += 8;
-      PlatformDependent.copyMemory(
-        strBytes,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        baseObject,
-        position,
-        strBytes.length);
-      position += strBytes.length;
-    }
-    // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
-    // use a dummy comparator
-    final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() {
-      @Override
-      public int compare(
-        Object leftBaseObject,
-        long leftBaseOffset,
-        Object rightBaseObject,
-        long rightBaseOffset) {
-        return 0;
-      }
-    };
-    // Compute key prefixes based on the records' partition ids
-    final HashPartitioner hashPartitioner = new HashPartitioner(4);
-    // Use integer comparison for comparing prefixes (which are partition ids, in this case)
-    final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() {
-      @Override
-      public int compare(long prefix1, long prefix2) {
-        return (int) prefix1 - (int) prefix2;
-      }
-    };
-    final UnsafeSorter sorter = new UnsafeSorter(
+
+    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
       memoryManager,
+      shuffleMemoryManager,
+      blockManager,
       recordComparator,
       prefixComparator,
-      dataToSort.length);
-    // Given a page of records, insert those records into the sorter one-by-one:
-    position = dataPage.getBaseOffset();
-    for (int i = 0; i < dataToSort.length; i++) {
-      // position now points to the start of a record (which holds its length).
-      final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
-      final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
-      final String str = getStringFromDataPage(baseObject, position);
-      final int partitionId = hashPartitioner.getPartition(str);
-      sorter.insertRecord(address, partitionId);
-      position += 8 + recordLength;
-    }
-    final Iterator<UnsafeSorter.RecordPointerAndKeyPrefix> iter = sorter.getSortedIterator();
-    int iterLength = 0;
-    long prevPrefix = -1;
-    Arrays.sort(dataToSort);
-    while (iter.hasNext()) {
-      final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next();
-      final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer);
-      final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer);
-      final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset);
-      Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1);
-      Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " +
-        prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix);
-      prevPrefix = pointerAndPrefix.keyPrefix;
-      iterLength++;
-    }
-    Assert.assertEquals(dataToSort.length, iterLength);
+      1024,
+      new SparkConf());
+
+    insertNumber(sorter, 5);
+    insertNumber(sorter, 1);
+    insertNumber(sorter, 3);
+    sorter.spill();
+    insertNumber(sorter, 4);
+    insertNumber(sorter, 2);
+
+    Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> iter =
+      sorter.getSortedIterator();
+
+    Assert.assertEquals(1, iter.next().keyPrefix);
+    Assert.assertEquals(2, iter.next().keyPrefix);
+    Assert.assertEquals(3, iter.next().keyPrefix);
+    Assert.assertEquals(4, iter.next().keyPrefix);
+    Assert.assertEquals(5, iter.next().keyPrefix);
+    Assert.assertFalse(iter.hasNext());
+    // TODO: check that the values are also read back properly.
+
+    // TODO: test for cleanup:
+    // assert(tempDir.isEmpty)
   }
 }
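The suite stubs `tryToAcquire` with `returnsFirstArg()` so every memory request is granted in full, and fabricates real temp files for spills, which keeps `UnsafeExternalSorter` unaware that it is running against mocks. That same wiring could later serve the TODO in `UnsafeSorterSpillWriter` about detecting a dropped `recordWritten()` call. A hypothetical follow-up test, not part of this patch, assuming page allocation goes through `tryToAcquire` as the `release()` call added in `freeMemory()` suggests:

```java
// Hypothetical test body reusing the fields and mocks wired up in setUp():
@Test
public void spillInteractsWithShuffleMemoryManager() throws Exception {
  final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
    memoryManager, shuffleMemoryManager, blockManager,
    recordComparator, prefixComparator, 1024, new SparkConf());
  insertNumber(sorter, 1);
  sorter.spill();
  // The stub above (returnsFirstArg) granted whatever was asked for, so a
  // page allocation should show up as at least one tryToAcquire call:
  verify(shuffleMemoryManager, atLeastOnce()).tryToAcquire(anyLong());
}
```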
diff --git a/pom.xml b/pom.xml
index c85c5feeaf383..57d857340f735 100644
--- a/pom.xml
+++ b/pom.xml
@@ -652,7 +652,7 @@
       <dependency>
         <groupId>org.mockito</groupId>
         <artifactId>mockito-all</artifactId>
-        <version>1.9.0</version>
+        <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
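The Mockito bump is presumably driven by the new suite's use of `org.mockito.AdditionalAnswers` (`returnsFirstArg`, `returnsSecondArg`), which older releases lack. A standalone demo of the helper, outside any Spark types:

```java
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Mockito.*;

import java.util.concurrent.ConcurrentMap;

final class AdditionalAnswersDemo {
  @SuppressWarnings("unchecked")
  public static void main(String[] args) {
    ConcurrentMap<String, String> map = mock(ConcurrentMap.class);
    // Echo back the second argument, like the wrapForCompression stub in setUp():
    when(map.putIfAbsent(anyString(), anyString())).then(returnsSecondArg());
    System.out.println(map.putIfAbsent("k", "v"));  // prints: v
  }
}
```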