From f156a8f830af6867015324592801255a2bbcb17b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 May 2015 23:23:14 -0700 Subject: [PATCH] Hacky metrics integration; refactor some interfaces. --- .../shuffle/unsafe/UnsafeShuffleWriter.java | 24 +++--- .../unsafe/sort/ExternalSorterIterator.java | 31 ++++++++ .../unsafe/sort/UnsafeExternalSorter.java | 76 ++++++++++++++++--- .../spark/unsafe/sort/UnsafeSorter.java | 23 ++++-- ...rger.java => UnsafeSorterSpillMerger.java} | 8 +- .../unsafe/sort/UnsafeSorterSpillReader.java | 2 +- .../unsafe/sort/UnsafeSorterSpillWriter.java | 6 +- .../sort/UnsafeExternalSorterSuite.java | 26 ++++--- 8 files changed, 151 insertions(+), 45 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java rename core/src/main/java/org/apache/spark/unsafe/sort/{UnsafeExternalSortSpillMerger.java => UnsafeSorterSpillMerger.java} (94%) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 0ea11e823d1d4..d142cf59d8085 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -18,7 +18,7 @@ package org.apache.spark.shuffle.unsafe; import org.apache.spark.*; -import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger; +import org.apache.spark.unsafe.sort.ExternalSorterIterator; import org.apache.spark.unsafe.sort.UnsafeExternalSorter; import scala.Option; import scala.Product2; @@ -28,7 +28,6 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.Iterator; import java.util.LinkedList; import com.esotericsoftware.kryo.io.ByteBufferOutputStream; @@ -47,7 +46,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.unsafe.sort.UnsafeSorter; + import static org.apache.spark.unsafe.sort.UnsafeSorter.*; // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles @@ -64,7 +63,6 @@ public class UnsafeShuffleWriter implements ShuffleWriter { private final SerializerInstance serializer; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; - private final LinkedList allocatedPages = new LinkedList(); private final int fileBufferSize; private MapStatus mapStatus = null; @@ -108,12 +106,13 @@ private void freeMemory() { // TODO: free sorter memory } - private Iterator sortRecords( + private ExternalSorterIterator sortRecords( scala.collection.Iterator> records) throws Exception { final UnsafeExternalSorter sorter = new UnsafeExternalSorter( memoryManager, SparkEnv$.MODULE$.get().shuffleMemoryManager(), SparkEnv$.MODULE$.get().blockManager(), + TaskContext.get(), RECORD_COMPARATOR, PREFIX_COMPARATOR, 4096, // Initial size (TODO: tune this!) @@ -145,8 +144,7 @@ private Iterator sortRe return sorter.getSortedIterator(); } - private long[] writeSortedRecordsToFile( - Iterator sortedRecords) throws IOException { + private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException { final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId); final ShuffleBlockId blockId = new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID()); @@ -157,8 +155,8 @@ private long[] writeSortedRecordsToFile( final byte[] arr = new byte[SER_BUFFER_SIZE]; while (sortedRecords.hasNext()) { - final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next(); - final int partition = (int) recordPointer.keyPrefix; + sortedRecords.loadNext(); + final int partition = (int) sortedRecords.keyPrefix; assert (partition >= currentPartition); if (partition != currentPartition) { // Switch to the new partition @@ -172,13 +170,13 @@ private long[] writeSortedRecordsToFile( } PlatformDependent.copyMemory( - recordPointer.baseObject, - recordPointer.baseOffset + 4, + sortedRecords.baseObject, + sortedRecords.baseOffset + 4, arr, PlatformDependent.BYTE_ARRAY_OFFSET, - recordPointer.recordLength); + sortedRecords.recordLength); assert (writer != null); // To suppress an IntelliJ warning - writer.write(arr, 0, recordPointer.recordLength); + writer.write(arr, 0, sortedRecords.recordLength); // TODO: add a test that detects whether we leave this call out: writer.recordWritten(); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java b/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java new file mode 100644 index 0000000000000..d53a0baaf351f --- /dev/null +++ b/core/src/main/java/org/apache/spark/unsafe/sort/ExternalSorterIterator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.sort; + +public abstract class ExternalSorterIterator { + + public Object baseObject; + public long baseOffset; + public int recordLength; + public long keyPrefix; + + public abstract boolean hasNext(); + + public abstract void loadNext(); + +} diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java index 42669055bcf1c..bf0019c51703f 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java @@ -19,12 +19,15 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.io.IOException; import java.util.Iterator; @@ -37,16 +40,20 @@ */ public final class UnsafeExternalSorter { + private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; + private int numSpills = 0; private UnsafeSorter sorter; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; + private final TaskContext taskContext; private final LinkedList allocatedPages = new LinkedList(); private final boolean spillingEnabled; private final int fileBufferSize; @@ -63,13 +70,15 @@ public UnsafeExternalSorter( TaskMemoryManager memoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, + TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - SparkConf conf) { + SparkConf conf) throws IOException { this.memoryManager = memoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; + this.taskContext = taskContext; this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; this.initialSize = initialSize; @@ -81,9 +90,19 @@ public UnsafeExternalSorter( // TODO: metrics tracking + integration with shuffle write metrics - private void openSorter() { + private void openSorter() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); // TODO: connect write metrics to task metrics? + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L * 2; + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire memory!"); + } + } + this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize); } @@ -101,23 +120,52 @@ public void spill() throws IOException { spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix); } spillWriter.close(); + final long sorterMemoryUsage = sorter.getMemoryUsage(); sorter = null; - freeMemory(); + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes()); + numSpills++; + final long threadId = Thread.currentThread().getId(); + // TODO: messy; log _before_ spill + logger.info("Thread " + threadId + " spilling in-memory map of " + + org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" + + (numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)"); openSorter(); } - private void freeMemory() { + private long freeMemory() { + long memoryFreed = 0; final Iterator iter = allocatedPages.iterator(); while (iter.hasNext()) { memoryManager.freePage(iter.next()); shuffleMemoryManager.release(PAGE_SIZE); + memoryFreed += PAGE_SIZE; iter.remove(); } currentPage = null; currentPagePosition = -1; + return memoryFreed; } private void ensureSpaceInDataPage(int requiredSpace) throws Exception { + // TODO: merge these steps to first calculate total memory requirements for this insert, + // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the + // data page. + if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) { + final long oldSortBufferMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer); + if (memoryAcquired < memoryToGrowSortBuffer) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandSortBuffer(); + shuffleMemoryManager.release(oldSortBufferMemoryUsage); + } + } + final long spaceInCurrentPage; if (currentPage != null) { spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset()); @@ -129,12 +177,22 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception { throw new Exception("Required space " + requiredSpace + " is greater than page size (" + PAGE_SIZE + ")"); } else if (requiredSpace > spaceInCurrentPage) { - if (spillingEnabled && shuffleMemoryManager.tryToAcquire(PAGE_SIZE) < PAGE_SIZE) { - spill(); + if (spillingEnabled) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpill != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpill); + throw new Exception("Can't allocate memory!"); + } + } } currentPage = memoryManager.allocatePage(PAGE_SIZE); currentPagePosition = currentPage.getBaseOffset(); allocatedPages.add(currentPage); + logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE); } } @@ -162,9 +220,9 @@ public void insertRecord( sorter.insertRecord(recordAddress, prefix); } - public Iterator getSortedIterator() throws IOException { - final UnsafeExternalSortSpillMerger spillMerger = - new UnsafeExternalSortSpillMerger(recordComparator, prefixComparator); + public ExternalSorterIterator getSortedIterator() throws IOException { + final UnsafeSorterSpillMerger spillMerger = + new UnsafeSorterSpillMerger(recordComparator, prefixComparator); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpill(spillWriter.getReader(blockManager)); } diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java index 0f844c1997668..917cbdb564a15 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java @@ -86,10 +86,9 @@ public static abstract class PrefixComparator { */ private int sortBufferInsertPosition = 0; - private void expandSortBuffer(int newSize) { - assert (newSize > sortBuffer.length); + public void expandSortBuffer() { final long[] oldBuffer = sortBuffer; - sortBuffer = new long[newSize]; + sortBuffer = new long[oldBuffer.length * 2]; System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length); } @@ -122,14 +121,22 @@ public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix rig }; } + public long getMemoryUsage() { + return sortBuffer.length * 8L; + } + + public boolean hasSpaceForAnotherRecord() { + return sortBufferInsertPosition + 2 < sortBuffer.length; + } + /** * Insert a record into the sort buffer. * * @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}. */ public void insertRecord(long objectAddress, long keyPrefix) { - if (sortBufferInsertPosition + 2 == sortBuffer.length) { - expandSortBuffer(sortBuffer.length * 2); + if (!hasSpaceForAnotherRecord()) { + expandSortBuffer(); } sortBuffer[sortBufferInsertPosition] = objectAddress; sortBufferInsertPosition++; @@ -167,10 +174,10 @@ public void remove() { }; } - public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() { + public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() { sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator); - UnsafeExternalSortSpillMerger.MergeableIterator iter = - new UnsafeExternalSortSpillMerger.MergeableIterator() { + UnsafeSorterSpillMerger.MergeableIterator iter = + new UnsafeSorterSpillMerger.MergeableIterator() { private int position = 0; private Object baseObject; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java similarity index 94% rename from core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java rename to core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java index c6bd4ee9df4ff..93278d5a26473 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillMerger.java @@ -23,7 +23,7 @@ import static org.apache.spark.unsafe.sort.UnsafeSorter.*; -public final class UnsafeExternalSortSpillMerger { +public final class UnsafeSorterSpillMerger { private final PriorityQueue priorityQueue; @@ -46,9 +46,9 @@ public static final class RecordAddressAndKeyPrefix { public long keyPrefix; } - public UnsafeExternalSortSpillMerger( - final RecordComparator recordComparator, - final UnsafeSorter.PrefixComparator prefixComparator) { + public UnsafeSorterSpillMerger( + final RecordComparator recordComparator, + final UnsafeSorter.PrefixComparator prefixComparator) { final Comparator comparator = new Comparator() { @Override diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java index 7c696240aaa73..894a593d41f3e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java @@ -24,7 +24,7 @@ import java.io.*; -public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger.MergeableIterator { +final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator { private final File file; private InputStream in; diff --git a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java index e0649122ac09c..6085df67d2c2e 100644 --- a/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java @@ -33,7 +33,7 @@ import java.io.*; import java.nio.ByteBuffer; -public final class UnsafeSorterSpillWriter { +final class UnsafeSorterSpillWriter { private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this public static final int EOF_MARKER = -1; @@ -122,6 +122,10 @@ public void close() throws IOException { arr = null; } + public long numberOfSpilledBytes() { + return file.length(); + } + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { return new UnsafeSorterSpillReader(blockManager, file, blockId); } diff --git a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java index 4f2aa9b895c01..e745074af075c 100644 --- a/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/sort/UnsafeExternalSorterSuite.java @@ -20,6 +20,8 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContextImpl; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleMemoryManager; @@ -41,7 +43,6 @@ import java.io.File; import java.io.InputStream; import java.io.OutputStream; -import java.nio.file.Files; import java.util.Iterator; import java.util.UUID; @@ -78,6 +79,7 @@ public int compare( BlockManager blockManager; DiskBlockManager diskBlockManager; File tempDir; + TaskContext taskContext; private static final class CompressStream extends AbstractFunction1 { @Override @@ -92,6 +94,7 @@ public void setUp() { diskBlockManager = mock(DiskBlockManager.class); blockManager = mock(BlockManager.class); tempDir = new File(Utils.createTempDir$default$1()); + taskContext = mock(TaskContext.class); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer>() { @@ -142,6 +145,7 @@ public void testSortingOnlyByPartitionId() throws Exception { memoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, 1024, @@ -154,14 +158,18 @@ public void testSortingOnlyByPartitionId() throws Exception { insertNumber(sorter, 4); insertNumber(sorter, 2); - Iterator iter = - sorter.getSortedIterator(); - - Assert.assertEquals(1, iter.next().keyPrefix); - Assert.assertEquals(2, iter.next().keyPrefix); - Assert.assertEquals(3, iter.next().keyPrefix); - Assert.assertEquals(4, iter.next().keyPrefix); - Assert.assertEquals(5, iter.next().keyPrefix); + ExternalSorterIterator iter = sorter.getSortedIterator(); + + iter.loadNext(); + Assert.assertEquals(1, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(2, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(3, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(4, iter.keyPrefix); + iter.loadNext(); + Assert.assertEquals(5, iter.keyPrefix); Assert.assertFalse(iter.hasNext()); // TODO: check that the values are also read back properly.