From 87e721b7501ba6f96db919384b52e90c7a8c8d91 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 May 2015 16:34:47 -0700 Subject: [PATCH] Renaming and comments --- .../unsafe/sort/UnsafeSortDataFormat.java | 15 ++--- .../spark/unsafe/sort/UnsafeSorter.java | 65 ++++++++++++++----- .../shuffle/unsafe/UnsafeShuffleManager.scala | 8 +-- .../spark/unsafe/sort/UnsafeSorterSuite.java | 6 +- 4 files changed, 64 insertions(+), 30 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java index 9955e3fcaabbb..290a87b70cad6 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java @@ -17,35 +17,34 @@ package org.apache.spark.unsafe.sort; -import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix; +import static org.apache.spark.unsafe.sort.UnsafeSorter.RecordPointerAndKeyPrefix; import org.apache.spark.util.collection.SortDataFormat; /** - * TODO: finish writing this description + * Supports sorting an array of (record pointer, key prefix) pairs. Used in {@link UnsafeSorter}. * * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ -final class UnsafeSortDataFormat - extends SortDataFormat { +final class UnsafeSortDataFormat extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } @Override - public KeyPointerAndPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @Override - public KeyPointerAndPrefix newKey() { - return new KeyPointerAndPrefix(); + public RecordPointerAndKeyPrefix newKey() { + return new RecordPointerAndKeyPrefix(); } @Override - public KeyPointerAndPrefix getKey(long[] data, int pos, KeyPointerAndPrefix reuse) { + public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { reuse.recordPointer = data[pos * 2]; reuse.keyPrefix = data[pos * 2 + 1]; return reuse; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index d33ca321a9835..7795ee6a5f0e2 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -23,9 +23,16 @@ import org.apache.spark.util.collection.Sorter; import org.apache.spark.unsafe.memory.TaskMemoryManager; +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. + */ public final class UnsafeSorter { - public static final class KeyPointerAndPrefix { + public static final class RecordPointerAndKeyPrefix { /** * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a * description of how these addresses are encoded. @@ -37,6 +44,7 @@ public static final class KeyPointerAndPrefix { */ public long keyPrefix; + // TODO: this was a carryover from test code; may want to remove this @Override public int hashCode() { throw new UnsupportedOperationException(); @@ -48,7 +56,17 @@ public boolean equals(Object obj) { } } + /** + * Compares records for ordering. In cases where the entire sorting key can fit in the 8-byte + * prefix, this may simply return 0. + */ public static abstract class RecordComparator { + /** + * Compare two records for order. + * + * @return a negative integer, zero, or a positive integer as the first record is less than, + * equal to, or greater than the second. + */ public abstract int compare( Object leftBaseObject, long leftBaseOffset, @@ -56,13 +74,16 @@ public abstract int compare( long rightBaseOffset); } + /** + * Given a pointer to a record, computes a prefix. + */ public static abstract class PrefixComputer { public abstract long computePrefix(Object baseObject, long baseOffset); } /** - * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific comparisons, - * such as lexicographic comparison for strings. + * Compares 8-byte key prefixes in prefix sort. Subclasses may implement type-specific + * comparisons, such as lexicographic comparison for strings. */ public static abstract class PrefixComparator { public abstract int compare(long prefix1, long prefix2); @@ -70,8 +91,8 @@ public static abstract class PrefixComparator { private final TaskMemoryManager memoryManager; private final PrefixComputer prefixComputer; - private final Sorter sorter; - private final Comparator sortComparator; + private final Sorter sorter; + private final Comparator sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at @@ -79,8 +100,12 @@ public static abstract class PrefixComparator { */ private long[] sortBuffer; + /** + * The position in the sort buffer where new records can be inserted. + */ private int sortBufferInsertPosition = 0; + private void expandSortBuffer(int newSize) { assert (newSize > sortBuffer.length); final long[] oldBuffer = sortBuffer; @@ -99,11 +124,13 @@ public UnsafeSorter( this.memoryManager = memoryManager; this.prefixComputer = prefixComputer; this.sorter = - new Sorter(UnsafeSortDataFormat.INSTANCE); - this.sortComparator = new Comparator() { + new Sorter(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new Comparator() { @Override - public int compare(KeyPointerAndPrefix left, KeyPointerAndPrefix right) { - if (left.keyPrefix == right.keyPrefix) { + public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix right) { + final int prefixComparisonResult = + prefixComparator.compare(left.keyPrefix, right.keyPrefix); + if (prefixComparisonResult == 0) { final Object leftBaseObject = memoryManager.getPage(left.recordPointer); final long leftBaseOffset = memoryManager.getOffsetInPage(left.recordPointer); final Object rightBaseObject = memoryManager.getPage(right.recordPointer); @@ -111,12 +138,17 @@ public int compare(KeyPointerAndPrefix left, KeyPointerAndPrefix right) { return recordComparator.compare( leftBaseObject, leftBaseOffset, rightBaseObject, rightBaseOffset); } else { - return prefixComparator.compare(left.keyPrefix, right.keyPrefix); + return prefixComparisonResult; } } }; } + /** + * Insert a record into the sort buffer. + * + * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + */ public void insertRecord(long objectAddress) { if (sortBufferInsertPosition + 2 == sortBuffer.length) { expandSortBuffer(sortBuffer.length * 2); @@ -130,11 +162,15 @@ public void insertRecord(long objectAddress) { sortBufferInsertPosition++; } - public Iterator getSortedIterator() { + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public Iterator getSortedIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - return new Iterator() { + return new Iterator() { private int position = 0; - private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix(); + private final RecordPointerAndKeyPrefix keyPointerAndPrefix = new RecordPointerAndKeyPrefix(); @Override public boolean hasNext() { @@ -142,7 +178,7 @@ public boolean hasNext() { } @Override - public KeyPointerAndPrefix next() { + public RecordPointerAndKeyPrefix next() { keyPointerAndPrefix.recordPointer = sortBuffer[position]; keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1]; position += 2; @@ -155,5 +191,4 @@ public void remove() { } }; } - } diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala index 489bcf42cb448..7eb825641263b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -32,7 +32,7 @@ import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager} import org.apache.spark.unsafe.sort.UnsafeSorter -import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator} +import org.apache.spark.unsafe.sort.UnsafeSorter.{RecordPointerAndKeyPrefix, PrefixComparator, PrefixComputer, RecordComparator} private class UnsafeShuffleHandle[K, V]( shuffleId: Int, @@ -122,7 +122,7 @@ private class UnsafeShuffleWriter[K, V]( private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance() private def sortRecords( - records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = { + records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[RecordPointerAndKeyPrefix] = { val sorter = new UnsafeSorter( context.taskMemoryManager(), DummyRecordComparator, @@ -194,7 +194,7 @@ private class UnsafeShuffleWriter[K, V]( } private def writeSortedRecordsToFile( - sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = { + sortedRecords: java.util.Iterator[RecordPointerAndKeyPrefix]): Array[Long] = { val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID) val partitionLengths = new Array[Long](partitioner.numPartitions) @@ -223,7 +223,7 @@ private class UnsafeShuffleWriter[K, V]( } while (sortedRecords.hasNext) { - val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next() + val keyPointerAndPrefix: RecordPointerAndKeyPrefix = sortedRecords.next() val partition = keyPointerAndPrefix.keyPrefix.toInt if (partition != currentPartition) { switchToPartition(partition) diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java index 4c3b982747693..2f88df1210bbc 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeSorterSuite.java @@ -52,7 +52,7 @@ public void testSortingEmptyInput() { mock(UnsafeSorter.PrefixComputer.class), mock(UnsafeSorter.PrefixComparator.class), 100); - final Iterator iter = sorter.getSortedIterator(); + final Iterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -130,12 +130,12 @@ public int compare(long prefix1, long prefix2) { sorter.insertRecord(address); position += 8 + recordLength; } - final Iterator iter = sorter.getSortedIterator(); + final Iterator iter = sorter.getSortedIterator(); int iterLength = 0; long prevPrefix = -1; Arrays.sort(dataToSort); while (iter.hasNext()) { - final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next(); + final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next(); final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer); final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer); final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset);