From 5e100b2db9c0cb986f7b90fd58d7274b83236241 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 3 May 2015 19:19:38 -0700
Subject: [PATCH] Super-messy WIP on external sort

---
 .../shuffle/unsafe/UnsafeShuffleWriter.java   |  79 ++------
 .../sort/UnsafeExternalSortSpillMerger.java   | 106 +++++++++++
 .../unsafe/sort/UnsafeExternalSorter.java     | 172 ++++++++++++++++++
 .../spark/unsafe/sort/UnsafeSorter.java       |  54 ++++--
 .../unsafe/sort/UnsafeSorterSpillReader.java  |  93 ++++++++++
 .../unsafe/sort/UnsafeSorterSpillWriter.java  |  86 +++++++++
 .../spark/util/collection/Spillable.scala     |  14 +-
 .../sort/UnsafeExternalSorterSuite.java       | 136 ++++++++++++++
 8 files changed, 663 insertions(+), 77 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
 create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
 create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
 create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
 create mode 100644 core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java

diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 9554298c0f3f8..0ea11e823d1d4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.shuffle.unsafe;
 
+import org.apache.spark.*;
+import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger;
+import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
 import scala.Option;
 import scala.Product2;
 import scala.reflect.ClassTag;
@@ -30,10 +33,6 @@
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
 
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
-import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +53,6 @@
 // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
 public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
 
-  private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
   private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
@@ -70,9 +68,6 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private final int fileBufferSize;
   private MapStatus mapStatus = null;
 
-  private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
    * and then call stop() with success = false if they get an exception, we want to make sure
@@ -109,39 +104,20 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) {
     }
   }
 
-  private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
-    final long spaceInCurrentPage;
-    if (currentPage != null) {
-      spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
-    } else {
-      spaceInCurrentPage = 0;
-    }
-    if (requiredSpace > PAGE_SIZE) {
-      // TODO: throw a more specific exception?
- throw new Exception("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); - } else if (requiredSpace > spaceInCurrentPage) { - currentPage = memoryManager.allocatePage(PAGE_SIZE); - currentPagePosition = currentPage.getBaseOffset(); - allocatedPages.add(currentPage); - } - } - private void freeMemory() { - final Iterator iter = allocatedPages.iterator(); - while (iter.hasNext()) { - memoryManager.freePage(iter.next()); - iter.remove(); - } + // TODO: free sorter memory } - private Iterator sortRecords( - scala.collection.Iterator> records) throws Exception { - final UnsafeSorter sorter = new UnsafeSorter( + private Iterator sortRecords( + scala.collection.Iterator> records) throws Exception { + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, + SparkEnv$.MODULE$.get().shuffleMemoryManager(), + SparkEnv$.MODULE$.get().blockManager(), RECORD_COMPARATOR, PREFIX_COMPARATOR, - 4096 // Initial size (TODO: tune this!) + 4096, // Initial size (TODO: tune this!) + SparkEnv$.MODULE$.get().conf() ); final byte[] serArray = new byte[SER_BUFFER_SIZE]; @@ -161,30 +137,16 @@ private Iterator sortRecords( final int serializedRecordSize = serByteBuffer.position(); assert (serializedRecordSize > 0); - // Need 4 bytes to store the record length. - ensureSpaceInDataPage(serializedRecordSize + 4); - - final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); - final Object baseObject = currentPage.getBaseObject(); - PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize); - currentPagePosition += 4; - PlatformDependent.copyMemory( - serArray, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - currentPagePosition, - serializedRecordSize); - currentPagePosition += serializedRecordSize; - sorter.insertRecord(recordAddress, partitionId); + sorter.insertRecord( + serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); } return sorter.getSortedIterator(); } private long[] writeSortedRecordsToFile( - Iterator sortedRecords) throws IOException { + Iterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -195,7 +157,7 @@ private long[] writeSortedRecordsToFile( final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { - final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next(); + final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next(); final int partition = (int) recordPointer.keyPrefix; assert (partition >= currentPartition); if (partition != currentPartition) { @@ -209,17 +171,14 @@ private long[] writeSortedRecordsToFile( blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics); } - final Object baseObject = memoryManager.getPage(recordPointer.recordPointer); - final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer); - final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); PlatformDependent.copyMemory( - baseObject, - baseOffset + 4, + recordPointer.baseObject, + recordPointer.baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - recordLength); + recordPointer.recordLength); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, recordLength); + writer.write(arr, 0, 
       // TODO: add a test that detects whether we leave this call out:
       writer.recordWritten();
     }
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
new file mode 100644
index 0000000000000..89928ffaa448d
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.sort;
+
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.PriorityQueue;
+
+import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
+
+public final class UnsafeExternalSortSpillMerger {
+
+  private final PriorityQueue<MergeableIterator> priorityQueue;
+
+  /**
+   * An iterator over sorted records. The merge queue orders iterators by the record that each
+   * one has most recently loaded, so advanceRecord() must be called once before an iterator is
+   * enqueued and again each time it is re-enqueued.
+   */
+  public static abstract class MergeableIterator {
+    /** True if another record can be loaded by calling advanceRecord(). */
+    public abstract boolean hasNext();
+
+    /** Loads the next record; the getters below then describe that record. */
+    public abstract void advanceRecord();
+
+    public abstract long getPrefix();
+
+    public abstract Object getBaseObject();
+
+    public abstract long getBaseOffset();
+
+    public abstract int getRecordLength();
+  }
+
+  public static final class RecordAddressAndKeyPrefix {
+    public Object baseObject;
+    public long baseOffset;
+    public int recordLength;
+    public long keyPrefix;
+  }
+
+  public UnsafeExternalSortSpillMerger(
+      final RecordComparator recordComparator,
+      final UnsafeSorter.PrefixComparator prefixComparator) {
+    final Comparator<MergeableIterator> comparator = new Comparator<MergeableIterator>() {
+
+      @Override
+      public int compare(MergeableIterator left, MergeableIterator right) {
+        final int prefixComparisonResult =
+          prefixComparator.compare(left.getPrefix(), right.getPrefix());
+        if (prefixComparisonResult == 0) {
+          return recordComparator.compare(
+            left.getBaseObject(), left.getBaseOffset(),
+            right.getBaseObject(), right.getBaseOffset());
+        } else {
+          return prefixComparisonResult;
+        }
+      }
+    };
+    priorityQueue = new PriorityQueue<MergeableIterator>(10, comparator);
+  }
+
+  public void addSpill(MergeableIterator spillReader) {
+    if (spillReader.hasNext()) {
+      // Load the first record so that the queue's comparator has something to compare:
+      spillReader.advanceRecord();
+      priorityQueue.add(spillReader);
+    }
+  }
+
+  public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {
+    return new Iterator<RecordAddressAndKeyPrefix>() {
+
+      private MergeableIterator spillReader;
+      private final RecordAddressAndKeyPrefix record = new RecordAddressAndKeyPrefix();
+
+      @Override
+      public boolean hasNext() {
+        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
+      }
+
+      @Override
+      public RecordAddressAndKeyPrefix next() {
+        // Re-enqueue the iterator that produced the previous record, advanced to its next one:
+        if (spillReader != null && spillReader.hasNext()) {
+          spillReader.advanceRecord();
+          priorityQueue.add(spillReader);
+        }
+        spillReader = priorityQueue.poll();
+        record.baseObject = spillReader.getBaseObject();
+        record.baseOffset = spillReader.getBaseOffset();
+        record.recordLength = spillReader.getRecordLength();
+        record.keyPrefix = spillReader.getPrefix();
+        return record;
+      }
+
+      @Override
+      public void remove() {
+        throw new UnsupportedOperationException();
+      }
+    };
+  }
+
+}
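The merge strategy above (one MergeableIterator per sorted run, held in a PriorityQueue keyed by each iterator's most recently loaded record) is easier to see on plain arrays. Below is a minimal, self-contained sketch of the same pattern; the names Run and mergeRuns are illustrative and not part of this patch.

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

final class KWayMergeSketch {
  // A sorted run with a cursor, analogous to MergeableIterator: the queue orders
  // runs by their current element, the way the merger orders iterators by the
  // current record's key prefix.
  private static final class Run {
    final int[] values;
    int pos = 0;
    Run(int[] values) { this.values = values; }
    int current() { return values[pos]; }
    boolean hasNext() { return pos < values.length; }
  }

  static List<Integer> mergeRuns(List<int[]> runs) {
    final PriorityQueue<Run> queue =
      new PriorityQueue<Run>(Math.max(1, runs.size()), new Comparator<Run>() {
        @Override
        public int compare(Run a, Run b) { return Integer.compare(a.current(), b.current()); }
      });
    for (int[] r : runs) {
      if (r.length > 0) queue.add(new Run(r));
    }
    final List<Integer> out = new ArrayList<Integer>();
    while (!queue.isEmpty()) {
      final Run min = queue.poll();       // run with the smallest current element
      out.add(min.current());
      min.pos++;                          // advance, then re-insert if not exhausted,
      if (min.hasNext()) queue.add(min);  // mirroring advanceRecord() + re-add in next()
    }
    return out;
  }

  public static void main(String[] args) {
    System.out.println(mergeRuns(java.util.Arrays.asList(
      new int[] {1, 4, 9}, new int[] {2, 3, 8}, new int[] {5})));
    // [1, 2, 3, 4, 5, 8, 9]
  }
}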
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
new file mode 100644
index 0000000000000..613d07cf6a316
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.sort;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.SparkConf;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
+
+/**
+ * External sorter based on {@link UnsafeSorter}.
+ */
+public final class UnsafeExternalSorter {
+
+  private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
+  private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
+
+  private final PrefixComparator prefixComparator;
+  private final RecordComparator recordComparator;
+  private final int initialSize;
+  private UnsafeSorter sorter;
+
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final BlockManager blockManager;
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+  private final boolean spillingEnabled;
+  private final int fileBufferSize;
+  private ShuffleWriteMetrics writeMetrics;
+
+  private MemoryBlock currentPage = null;
+  private long currentPagePosition = -1;
+
+  private final LinkedList<UnsafeSorterSpillWriter> spillWriters =
+    new LinkedList<UnsafeSorterSpillWriter>();
+
+  public UnsafeExternalSorter(
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      RecordComparator recordComparator,
+      PrefixComparator prefixComparator,
+      int initialSize,
+      SparkConf conf) {
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.blockManager = blockManager;
+    this.recordComparator = recordComparator;
+    this.prefixComparator = prefixComparator;
+    this.initialSize = initialSize;
+    this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true);
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+    this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    openSorter();
+  }
+
+  // TODO: metrics tracking + integration with shuffle write metrics
+
+  private void openSorter() {
+    this.writeMetrics = new ShuffleWriteMetrics();
+    // TODO: connect write metrics to task metrics?
+    this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize);
+  }
+
+  @VisibleForTesting
+  public void spill() throws IOException {
+    final UnsafeSorterSpillWriter spillWriter =
+      new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics);
+    // Remember this spill so that getSortedIterator() can merge it back in later:
+    spillWriters.add(spillWriter);
+    final Iterator<RecordPointerAndKeyPrefix> sortedRecords = sorter.getSortedIterator();
+    while (sortedRecords.hasNext()) {
+      final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
+      final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
+      final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
+      // The record length header was written with putInt, so read it back with getInt:
+      final int recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
+      spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix);
+    }
+    spillWriter.close();
+    sorter = null;
+    freeMemory();
+    openSorter();
+  }
+
+  private void freeMemory() {
+    final Iterator<MemoryBlock> iter = allocatedPages.iterator();
+    while (iter.hasNext()) {
+      memoryManager.freePage(iter.next());
+      iter.remove();
+    }
+    // Guard against writing into a freed page on the next insert:
+    currentPage = null;
+    currentPagePosition = -1;
+  }
+
+  private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
+    final long spaceInCurrentPage;
+    if (currentPage != null) {
+      spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
+    } else {
+      spaceInCurrentPage = 0;
+    }
+    if (requiredSpace > PAGE_SIZE) {
+      // TODO: throw a more specific exception?
+ throw new Exception("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else if (requiredSpace > spaceInCurrentPage) { + if (spillingEnabled && shuffleMemoryManager.tryToAcquire(PAGE_SIZE) < PAGE_SIZE) { + spill(); + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); + } + } + + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + long prefix) throws Exception { + // Need 4 bytes to store the record length. + ensureSpaceInDataPage(lengthInBytes + 4); + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + + sorter.insertRecord(recordAddress, prefix); + } + + public Iterator getSortedIterator() throws IOException { + final UnsafeExternalSortSpillMerger spillMerger = + new UnsafeExternalSortSpillMerger(recordComparator, prefixComparator); + for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + spillMerger.addSpill(spillWriter.getReader(blockManager)); + } + spillWriters.clear(); + spillMerger.addSpill(sorter.getMergeableIterator()); + return spillMerger.getSortedIterator(); + } +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 092e26f4ee1fc..1801585e2ed84 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -43,17 +43,6 @@ public static final class RecordPointerAndKeyPrefix { * A key prefix, for use in comparisons. 
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
index 092e26f4ee1fc..1801585e2ed84 100644
--- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java
@@ -43,17 +43,6 @@ public static final class RecordPointerAndKeyPrefix {
      * A key prefix, for use in comparisons.
      */
     public long keyPrefix;
-
-    // TODO: this was a carryover from test code; may want to remove this
-    @Override
-    public int hashCode() {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-      throw new UnsupportedOperationException();
-    }
   }
 
   /**
@@ -82,6 +71,7 @@ public static abstract class PrefixComparator {
     public abstract int compare(long prefix1, long prefix2);
   }
 
+  private final TaskMemoryManager memoryManager;
   private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
@@ -96,7 +86,6 @@ public static abstract class PrefixComparator {
    */
   private int sortBufferInsertPosition = 0;
 
-
   private void expandSortBuffer(int newSize) {
     assert (newSize > sortBuffer.length);
     final long[] oldBuffer = sortBuffer;
@@ -111,6 +100,7 @@ public UnsafeSorter(
       int initialSize) {
     assert (initialSize > 0);
     this.sortBuffer = new long[initialSize * 2];
+    this.memoryManager = memoryManager;
     this.sorter = new Sorter<RecordPointerAndKeyPrefix, long[]>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new Comparator<RecordPointerAndKeyPrefix>() {
@@ -176,4 +166,44 @@ public void remove() {
       }
     };
   }
+
+  public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() {
+    sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
+    return new UnsafeExternalSortSpillMerger.MergeableIterator() {
+
+      private int position = 0;
+      private Object baseObject;
+      private long baseOffset;
+      private long keyPrefix;
+      private int recordLength;
+
+      @Override
+      public boolean hasNext() {
+        return position < sortBufferInsertPosition;
+      }
+
+      @Override
+      public void advanceRecord() {
+        final long recordPointer = sortBuffer[position];
+        baseObject = memoryManager.getPage(recordPointer);
+        final long recordOffset = memoryManager.getOffsetInPage(recordPointer);
+        // In-memory records are stored as a 4-byte length header followed by the
+        // record data, so expose the data address and length, matching the spill readers:
+        recordLength = PlatformDependent.UNSAFE.getInt(baseObject, recordOffset);
+        baseOffset = recordOffset + 4;
+        keyPrefix = sortBuffer[position + 1];
+        position += 2;
+      }
+
+      @Override
+      public long getPrefix() {
+        return keyPrefix;
+      }
+
+      @Override
+      public Object getBaseObject() {
+        return baseObject;
+      }
+
+      @Override
+      public long getBaseOffset() {
+        return baseOffset;
+      }
+
+      @Override
+      public int getRecordLength() {
+        return recordLength;
+      }
+    };
+  }
 }
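getMergeableIterator() sorts the flat long[] sortBuffer in which each entry occupies two slots: the record pointer at the even index and the key prefix at the odd index. A sketch of sorting that pair-per-two-slots layout directly, with a plain insertion sort standing in for the Sorter/UnsafeSortDataFormat machinery (names are illustrative, not part of the patch):

final class SortBufferSketch {
  // Sort (pointer, prefix) pairs in-place by prefix; pairs are swapped two slots at a time.
  static void sortByPrefix(long[] buf, int numRecords) {
    for (int i = 1; i < numRecords; i++) {
      for (int j = i; j > 0 && buf[2 * j + 1] < buf[2 * (j - 1) + 1]; j--) {
        swapPair(buf, j, j - 1);
      }
    }
  }

  private static void swapPair(long[] buf, int a, int b) {
    final long t0 = buf[2 * a], t1 = buf[2 * a + 1];
    buf[2 * a] = buf[2 * b];
    buf[2 * a + 1] = buf[2 * b + 1];
    buf[2 * b] = t0;
    buf[2 * b + 1] = t1;
  }

  public static void main(String[] args) {
    final long[] buf = {100, 3, 200, 1, 300, 2}; // (pointer, prefix) pairs
    sortByPrefix(buf, 3);
    System.out.println(java.util.Arrays.toString(buf)); // [200, 1, 300, 2, 100, 3]
  }
}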
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
new file mode 100644
index 0000000000000..e2d5e6a8faa10
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.sort;
+
+import com.google.common.io.ByteStreams;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.PlatformDependent;
+
+import java.io.*;
+
+public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger.MergeableIterator {
+
+  private final File file;
+  private InputStream in;
+  private DataInputStream din;
+
+  // Read-ahead of the next record's length, so that hasNext() is exact:
+  private int nextRecordLength;
+
+  private long keyPrefix;
+  private int recordLength;
+  private final byte[] arr = new byte[1024 * 1024];  // TODO: tune this (maybe grow dynamically)?
+  private final Object baseObject = arr;
+  private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
+
+  public UnsafeSorterSpillReader(
+      BlockManager blockManager,
+      File file,
+      BlockId blockId) throws IOException {
+    this.file = file;
+    assert (file.length() > 0);
+    final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
+    this.in = blockManager.wrapForCompression(blockId, bs);
+    this.din = new DataInputStream(this.in);
+    nextRecordLength = din.readInt();
+    if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+      close();
+    }
+  }
+
+  @Override
+  public boolean hasNext() {
+    return (in != null);
+  }
+
+  @Override
+  public void advanceRecord() {
+    try {
+      recordLength = nextRecordLength;
+      keyPrefix = din.readLong();
+      ByteStreams.readFully(in, arr, 0, recordLength);
+      nextRecordLength = din.readInt();
+      if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+        close();
+      }
+    } catch (IOException e) {
+      PlatformDependent.throwException(e);
+    }
+  }
+
+  private void close() throws IOException {
+    in.close();
+    in = null;
+  }
+
+  @Override
+  public long getPrefix() {
+    return keyPrefix;
+  }
+
+  @Override
+  public Object getBaseObject() {
+    return baseObject;
+  }
+
+  @Override
+  public long getBaseOffset() {
+    return baseOffset;
+  }
+
+  @Override
+  public int getRecordLength() {
+    return recordLength;
+  }
+}
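The reader above and the writer below agree on a simple stream format: per record an int length, a long key prefix, then the raw bytes, terminated by an EOF marker of -1. A stand-alone round trip of that format over in-memory streams (illustrative, not part of the patch):

import java.io.*;

final class SpillFormatSketch {
  private static final int EOF_MARKER = -1;

  public static void main(String[] args) throws IOException {
    final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    final DataOutputStream out = new DataOutputStream(bytes);
    for (String s : new String[] {"Boba", "Taho"}) {
      final byte[] payload = s.getBytes("UTF-8");
      out.writeInt(payload.length);  // record length
      out.writeLong(s.length());     // key prefix (any 8-byte value)
      out.write(payload);            // record data
    }
    out.writeInt(EOF_MARKER);        // what close() appends in the patch
    out.flush();

    final DataInputStream in =
      new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
    int length;
    while ((length = in.readInt()) != EOF_MARKER) {
      final long prefix = in.readLong();
      final byte[] payload = new byte[length];
      in.readFully(payload);
      System.out.println(prefix + " -> " + new String(payload, "UTF-8"));
    }
  }
}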
diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
new file mode 100644
index 0000000000000..fdda38d3f1c47
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.sort;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.JavaSerializerInstance;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.TempLocalBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+import scala.Tuple2;
+
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.IOException;
+
+public final class UnsafeSorterSpillWriter {
+
+  private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
+  public static final int EOF_MARKER = -1;
+
+  private byte[] arr = new byte[SER_BUFFER_SIZE];
+
+  private final File file;
+  private final BlockId blockId;
+  private final BlockObjectWriter writer;
+  private final DataOutputStream dos;
+
+  public UnsafeSorterSpillWriter(
+      BlockManager blockManager,
+      int fileBufferSize,
+      ShuffleWriteMetrics writeMetrics) throws IOException {
+    final Tuple2<TempLocalBlockId, File> spilledFileInfo =
+      blockManager.diskBlockManager().createTempLocalBlock();
+    this.file = spilledFileInfo._2();
+    this.blockId = spilledFileInfo._1();
+    // Dummy serializer: the disk writer API requires one, but the bytes written here
+    // are already serialized.
+    final SerializerInstance ser = new JavaSerializerInstance(0, false, null);
+    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
+    dos = new DataOutputStream(writer);
+  }
+
+  public void write(
+      Object baseObject,
+      long baseOffset,
+      int recordLength,
+      long keyPrefix) throws IOException {
+    // baseOffset points at the record's 4-byte length header; the data follows it.
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset + 4,
+      arr,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      recordLength);
+    dos.writeInt(recordLength);
+    dos.writeLong(keyPrefix);
+    writer.write(arr, 0, recordLength);
+    // TODO: add a test that detects whether we leave this call out:
+    writer.recordWritten();
+  }
+
+  public void close() throws IOException {
+    dos.writeInt(EOF_MARKER);
+    writer.commitAndClose();
+    arr = null;
+  }
+
+  public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
+    return new UnsafeSorterSpillReader(blockManager, file, blockId);
+  }
+}
\ No newline at end of file
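The Spillable change below only hoists initialMemoryThreshold into a companion object; the surrounding trait decides when to spill by trying to grow its memory grant and spilling if the manager refuses. A hedged Java sketch of that threshold-growing policy under simplified assumptions: MemoryManagerStub and maybeSpill are illustrative stand-ins rather than Spark APIs, and the elementsRead sampling that the real trait applies is omitted.

final class SpillPolicySketch {
  interface MemoryManagerStub { long tryToAcquire(long numBytes); }

  private static final long INITIAL_MEMORY_THRESHOLD = 5 * 1024 * 1024; // 5MB default
  private long myMemoryThreshold = INITIAL_MEMORY_THRESHOLD;

  /** Returns true if the caller should spill its in-memory collection. */
  boolean maybeSpill(long currentMemory, MemoryManagerStub manager) {
    if (currentMemory < myMemoryThreshold) {
      return false; // still under the tracking threshold
    }
    // Ask for enough to double our footprint before electing to spill:
    final long amountToRequest = 2 * currentMemory - myMemoryThreshold;
    final long granted = manager.tryToAcquire(amountToRequest);
    myMemoryThreshold += granted;
    return currentMemory >= myMemoryThreshold;
  }

  public static void main(String[] args) {
    final SpillPolicySketch policy = new SpillPolicySketch();
    // A manager that grants nothing forces a spill once the threshold is reached:
    System.out.println(policy.maybeSpill(6 * 1024 * 1024, new MemoryManagerStub() {
      @Override public long tryToAcquire(long numBytes) { return 0L; }
    })); // true
  }
}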
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 747ecf075a397..841a4cd791c4c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -20,11 +20,20 @@ package org.apache.spark.util.collection
 import org.apache.spark.Logging
 import org.apache.spark.SparkEnv
 
+private[spark] object Spillable {
+  // Initial threshold for the size of a collection before we start tracking its memory usage
+  val initialMemoryThreshold: Long =
+    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
+}
+
 /**
  * Spills contents of an in-memory collection to disk when the memory threshold
  * has been exceeded.
  */
 private[spark] trait Spillable[C] extends Logging {
+
+  import Spillable._
+
   /**
    * Spills the current in-memory collection to disk, and releases the memory.
    *
@@ -42,11 +51,6 @@ private[spark] trait Spillable[C] extends Logging {
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
-  // Initial threshold for the size of a collection before we start tracking its memory usage
-  // Exposed for testing
-  private[this] val initialMemoryThreshold: Long =
-    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
-
   // Threshold for this collection's size in bytes before we start tracking its memory usage
   // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
   private[this] var myMemoryThreshold = initialMemoryThreshold
diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java
new file mode 100644
index 0000000000000..e4376f1cea4fc
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.sort;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+public class UnsafeExternalSorterSuite {
+
+  private static String getStringFromDataPage(Object baseObject, long baseOffset) {
+    final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
+    final byte[] strBytes = new byte[strLength];
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset + 8,
+      strBytes,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      strLength);
+    return new String(strBytes);
+  }
+
+  /**
+   * Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
+   */
+  @Test
+  public void testSortingOnlyByPartitionId() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    // Write the records into the data page:
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final byte[] strBytes = str.getBytes("utf-8");
+      PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
+      position += 8;
+      PlatformDependent.copyMemory(
+        strBytes,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        baseObject,
+        position,
+        strBytes.length);
+      position += strBytes.length;
+    }
+    // Since the key fits within the 8-byte prefix, we don't need to do any record comparison,
+    // so use a dummy comparator:
+    final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() {
+      @Override
+      public int compare(
+          Object leftBaseObject,
+          long leftBaseOffset,
+          Object rightBaseObject,
+          long rightBaseOffset) {
+        return 0;
+      }
+    };
+    // Compute key prefixes based on the records' partition ids:
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+    // Use integer comparison for comparing prefixes (which are partition ids, in this case):
+    final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() {
+      @Override
+      public int compare(long prefix1, long prefix2) {
+        return (int) prefix1 - (int) prefix2;
+      }
+    };
+    final UnsafeSorter sorter = new UnsafeSorter(
+      memoryManager,
+      recordComparator,
+      prefixComparator,
+      dataToSort.length);
+    // Given a page of records, insert those records into the sorter one-by-one:
+    position = dataPage.getBaseOffset();
+    for (int i = 0; i < dataToSort.length; i++) {
+      // position now points to the start of a record (which holds its length).
+      final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
+      final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final String str = getStringFromDataPage(baseObject, position);
+      final int partitionId = hashPartitioner.getPartition(str);
+      sorter.insertRecord(address, partitionId);
+      position += 8 + recordLength;
+    }
+    final Iterator<UnsafeSorter.RecordPointerAndKeyPrefix> iter = sorter.getSortedIterator();
+    int iterLength = 0;
+    long prevPrefix = -1;
+    Arrays.sort(dataToSort);
+    while (iter.hasNext()) {
+      final UnsafeSorter.RecordPointerAndKeyPrefix pointerAndPrefix = iter.next();
+      final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer);
+      final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer);
+      final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset);
+      // binarySearch returns a negative insertion point (not necessarily -1) on a miss:
+      Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) >= 0);
+      Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " +
+        prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix);
+      prevPrefix = pointerAndPrefix.keyPrefix;
+      iterLength++;
+    }
+    Assert.assertEquals(dataToSort.length, iterLength);
+  }
+}
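A note on the prefix comparator in the test above: int subtraction is safe for small, non-negative partition ids, but a subtraction-based comparator overflows for general prefixes. A tiny demonstration (illustrative, not part of the patch):

final class PrefixComparatorSketch {
  public static void main(String[] args) {
    final int a = Integer.MIN_VALUE, b = 1;
    System.out.println(a - b);                 // overflows to 2147483647, i.e. "a > b"
    System.out.println(Integer.compare(a, b)); // -1, as intended
    // For full 64-bit prefixes, compare the longs directly:
    System.out.println(Long.compare(-1L, 1L)); // -1
  }
}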