diff --git a/core/pom.xml b/core/pom.xml index 262a3320db106..bfa49d0d6dc25 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -361,6 +361,16 @@ junit test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + com.novocode junit-interface diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java new file mode 100644 index 0000000000000..3f746b886bc9b --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.reflect.ClassTag; + +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. 
+ */ +final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + throw new UnsupportedOperationException(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 0000000000000..4ee6a82c0423e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *
+ * Within the long, the data is laid out as follows: + *
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based on the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *
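To make the pointer layout concrete, here is a small worked example (an illustrative sketch, not part of the patch; the page number, offset, and partition id values are made up). It mirrors the bit arithmetic of packPointer() and getRecordPointer() shown later in this file.

public class PackedPointerExample {
  public static void main(String[] args) {
    // A TaskMemoryManager-style address: 13-bit page number in bits 51..63, offset below it.
    long pageNumber = 5L;
    long offsetInPage = 1024L;
    long recordPointer = (pageNumber << 51) | offsetInPage;
    int partitionId = 7;

    // Pack: [24-bit partition id][13-bit page number][27-bit offset in page]
    long compressedAddress =
        ((recordPointer & ~((1L << 51) - 1)) >>> 24)   // move the page number down to bits 27..39
        | (recordPointer & ((1L << 27) - 1));          // keep the low 27-bit offset
    long packed = (((long) partitionId) << 40) | compressedAddress;

    // Unpack
    int decodedPartitionId = (int) ((packed & ~((1L << 40) - 1)) >>> 40);
    long decodedRecordPointer =
        ((packed << 24) & ~((1L << 51) - 1))           // page number back up to bits 51..63
        | (packed & ((1L << 27) - 1));                 // offset in page

    System.out.println(decodedPartitionId == partitionId);       // true
    System.out.println(decodedRecordPointer == recordPointer);   // true
  }
}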
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). + */ +final class PackedRecordPointer { + + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java new file mode 100644 index 0000000000000..7bac0dc0bbeb6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + */ +final class SpillInfo { + final long[] partitionLengths; + final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java new file mode 100644 index 0000000000000..9e9ed94b7890c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *
+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *
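For orientation, here is a stand-in sketch (not part of the patch) of the per-record layout that insertRecord() writes into a data page: a 4-byte length prefix followed by the serialized bytes. The real code uses PlatformDependent.UNSAFE on a MemoryBlock; a ByteBuffer is used here only to illustrate the format, and the record contents are made up.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class PageLayoutSketch {
  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(1 << 16);  // stand-in for a data page

    byte[] record = "serialized key/value bytes".getBytes(StandardCharsets.UTF_8);
    int recordOffset = page.position();   // what a packed pointer's 27-bit offset refers to
    page.putInt(record.length);           // 4-byte length prefix
    page.put(record);                     // record bytes follow immediately

    // writeSortedFile() reads it back the same way: the length first, then that many bytes.
    page.position(recordOffset);
    byte[] roundTrip = new byte[page.getInt()];
    page.get(roundTrip);
    System.out.println(new String(roundTrip, StandardCharsets.UTF_8));
  }
}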
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class UnsafeShuffleExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final int initialSize; + private final int numPartitions; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + private final LinkedList spills = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeShuffleInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.initialSize = initialSize; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + + this.writeMetrics = writeMetrics; + initializeForWriting(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. + * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. 
+ */ + private void writeSortedFile(boolean isLastFile) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + sorter.getSortedIterator(); + + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. + BlockObjectWriter writer; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. 
+ final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = memoryManager.getPage(recordPointer); + final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + PlatformDependent.copyMemory( + recordPage, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. 
+ */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + private long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. 
+ if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + int partitionId) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + freeSpaceInCurrentPage -= 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + freeSpaceInCurrentPage -= lengthInBytes; + sorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + try { + if (sorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java new file mode 100644 index 0000000000000..5bab501da9364 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +final class UnsafeShuffleInMemorySorter { + + private final Sorter sorter; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); + + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; + + /** + * The position in the pointer array where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeShuffleInMemorySorter(int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize]; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 1 < pointerArray.length; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); + } else { + expandPointerArray(); + } + } + pointerArray[pointerArrayInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + pointerArrayInsertPosition++; + } + + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + + private final long[] pointerArray; + private final int numRecords; + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; + + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + } + + public boolean hasNext() { + return position < numRecords; + } + + public void loadNext() { + packedRecordPointer.set(pointerArray[position]); + position++; + } + } + + /** + * Return an iterator over record pointers in sorted order. 
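As a rough usage sketch (not part of the patch, and assuming code placed in the same package since these classes are package-private), a caller drives the in-memory sorter as follows; the pointer values are fabricated stand-ins for addresses encoded by TaskMemoryManager.

package org.apache.spark.shuffle.unsafe;

public class InMemorySorterSketch {
  public static void main(String[] args) {
    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(16);
    sorter.insertRecord(128L, 3);   // (recordPointer, partitionId)
    sorter.insertRecord(256L, 1);

    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
    while (iter.hasNext()) {
      iter.loadNext();
      // Partition ids come back in non-decreasing order; writeSortedFile() relies on this.
      int partition = iter.packedRecordPointer.getPartitionId();
      long pointer = iter.packedRecordPointer.getRecordPointer();
      System.out.println(partition + " -> " + pointer);
    }
  }
}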
+ */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 0000000000000..a66d74ee44782 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.set(data[pos]); + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 0000000000000..ad7eb04afcd8c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Iterator; +import javax.annotation.Nullable; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConversions; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.annotation.Private; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +@Private +public class UnsafeShuffleWriter extends ShuffleWriter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + + private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + + private MapStatus mapStatus = null; + private UnsafeShuffleExternalSorter sorter = null; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. 
+ */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf sparkConf) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConversions.asScalaIterator(records)); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (!success) { + sorter.cleanupAfterError(); + } + } + } + + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + INITIAL_SORT_BUFFER_SIZE, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! 
spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. 
+ if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. 
+ */ + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } + } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. 
+ mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], + bytesToTransfer, + mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; + } + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + Closeables.close(mergedFileOutputChannel, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java new file mode 100644 index 0000000000000..dc2aa30466cc6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; + +/** + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. + */ +@Private +public final class TimeTrackingOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final OutputStream outputStream; + + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + final long startTime = System.nanoTime(); + outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void close() throws IOException { + final long startTime = System.nanoTime(); + outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0c4d28f786edd..a5d831c7e68ad 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -313,7 +313,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index dfbde7c8a1b0d..698d1384d580d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ 
b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6ad427bcac7f9..6c3b3080d2605 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index f6e6fe5defe09..4cc4ef5f1886e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. 
*/ -private[spark] trait ShuffleWriter[K, V] { +private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit + @throws[IOException] + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 897f0a5dc5bcc..eb87cee15903c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V]( writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 15842941daaab..d7fab351ca3b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager true } - override def shuffleBlockResolver: IndexShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { indexShuffleBlockResolver } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index add2656294ca2..c9dd6bfc4c219 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") sorter = new ExternalSorter[K, V, C]( diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 0000000000000..f2bfef376d3ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.keyOrdering.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. 
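As a usage sketch for the conditions listed above (not part of this patch): a job stays on the optimized path only if it uses a relocation-friendly serializer such as KryoSerializer, performs no map-side aggregation or key ordering, and produces fewer than 16777216 partitions. The snippet below enables the manager through the "tungsten-sort" short name registered in SparkEnv earlier in this diff; the master, application name, and RDD are illustrative only.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local[2]")                                   // illustrative
  .setAppName("tungsten-sort-demo")                        // illustrative
  .set("spark.shuffle.manager", "tungsten-sort")
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
val sc = new SparkContext(conf)
// A plain hash-partitioned shuffle with no aggregator and no key ordering, so
// canUseUnsafeShuffle should accept its ShuffleDependency:
sc.parallelize(1 to 100000, 8)
  .map(x => (x % 64, 1))
  .partitionBy(new HashPartitioner(64))
  .count()
sc.stop()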
+ * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can be spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") + } + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks.
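To make the 8-bytes-per-record point above concrete, here is a toy sketch of packing a 24-bit partition id (hence the 16777216-partition limit) and a 40-bit pointer into one word. The helper names and exact bit layout are illustrative, not a quote of PackedRecordPointer; the point is that with the partition id in the high bits, an ordinary numeric sort of the packed words groups records by partition, which is all the write path needs.

// Toy packing helpers; names and exact encoding are assumptions for illustration.
val PointerBits = 40
val PointerMask = (1L << PointerBits) - 1
def pack(recordPointer: Long, partitionId: Int): Long =
  (partitionId.toLong << PointerBits) | (recordPointer & PointerMask)
def partitionIdOf(packed: Long): Int = (packed >>> PointerBits).toInt
def recordPointerOf(packed: Long): Long = packed & PointerMask

// Sorting the raw 8-byte words orders records by partition id:
val words = Array(pack(400L, 3), pack(96L, 1), pack(250L, 2), pack(8L, 1))
java.util.Arrays.sort(words)
assert(words.map(partitionIdOf).toSeq == Seq(1, 1, 2, 3))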
*/ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + } + + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 8bc4e205bc3c6..a33f22ef52687 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter( extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - override def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - callWithTiming(out.write(b, off, len)) - } - override def close(): Unit = out.close() - override def flush(): Unit = out.flush() - } /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter( throw new IllegalStateException("Writer already closed. 
Cannot be reopened.") } fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) + ts = new TimeTrackingOutputStream(writeMetrics, fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializerInstance.serializeStream(bs) @@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - callWithTiming { - fos.getFD.sync() - } + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter( reportedPosition = pos } - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - // For testing private[spark] override def flush() { objOut.flush() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index b850973145077..df2d6ad3b41a4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7d5cf7b61e56a..3b9d14f9372b6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java new file mode 100644 index 0000000000000..db9e82759090a --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*; + +public class PackedRecordPointerSuite { + + @Test + public void heap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void offHeap() { + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock page0 = memoryManager.allocatePage(100); + final MemoryBlock page1 = memoryManager.allocatePage(100); + final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, + page1.getBaseOffset() + 42); + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360)); + assertEquals(360, packedPointer.getPartitionId()); + final long recordPointer = packedPointer.getRecordPointer(); + assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer)); + assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer)); + assertEquals(addressInPage1, recordPointer); + memoryManager.cleanUpAllAllocatedMemory(); + } + + @Test + public void maximumPartitionIdCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID)); + assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId()); + } + + @Test + public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + try { + // Pointers greater than the maximum partition ID will overflow or trigger an assertion error + packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)); + assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId()); + } catch (AssertionError e ) { + // pass + } + } + + @Test + public void maximumOffsetInPageCanBeEncoded() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = 
TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(address, packedPointer.getRecordPointer()); + } + + @Test + public void offsetsPastMaxOffsetInPageWillOverflow() { + PackedRecordPointer packedPointer = new PackedRecordPointer(); + long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES); + packedPointer.set(PackedRecordPointer.packPointer(address, 0)); + assertEquals(0, packedPointer.getRecordPointer()); + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java new file mode 100644 index 0000000000000..8fa72597db24d --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Arrays; +import java.util.Random; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public class UnsafeShuffleInMemorySorterSuite { + + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { + final byte[] strBytes = new byte[strLength]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, strLength); + return new String(strBytes); + } + + @Test + public void testSortingEmptyInput() { + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100); + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + assert(!iter.hasNext()); + } + + @Test + public void testBasicSorting() throws Exception { + final String[] dataToSort = new String[] { + "Boba", + "Pearls", + "Tapioca", + "Taho", + "Condensed Milk", + "Jasmine", + "Milk Tea", + "Lychee", + "Mango" + }; + final TaskMemoryManager memoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = memoryManager.allocatePage(2048); + final Object baseObject = dataPage.getBaseObject(); + final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + final HashPartitioner hashPartitioner = new HashPartitioner(4); + + // Write the records into the data page and store pointers into the sorter + long position = 
dataPage.getBaseOffset(); + for (String str : dataToSort) { + final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); + final byte[] strBytes = str.getBytes("utf-8"); + PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length); + position += 4; + PlatformDependent.copyMemory( + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + position, + strBytes.length); + position += strBytes.length; + sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str)); + } + + // Sort the records + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int prevPartitionId = -1; + Arrays.sort(dataToSort); + for (int i = 0; i < dataToSort.length; i++) { + Assert.assertTrue(iter.hasNext()); + iter.loadNext(); + final int partitionId = iter.packedRecordPointer.getPartitionId(); + Assert.assertTrue(partitionId >= 0 && partitionId <= 3); + Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId, + partitionId >= prevPartitionId); + final long recordAddress = iter.packedRecordPointer.getRecordPointer(); + final int recordLength = PlatformDependent.UNSAFE.getInt( + memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress)); + final String str = getStringFromDataPage( + memoryManager.getPage(recordAddress), + memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length + recordLength); + Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1); + } + Assert.assertFalse(iter.hasNext()); + } + + @Test + public void testSortingManyNumbers() throws Exception { + UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4); + int[] numbersToSort = new int[128000]; + Random random = new Random(16); + for (int i = 0; i < numbersToSort.length; i++) { + numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); + sorter.insertRecord(0, numbersToSort[i]); + } + Arrays.sort(numbersToSort); + int[] sorterResult = new int[numbersToSort.length]; + UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator(); + int j = 0; + while (iter.hasNext()) { + iter.loadNext(); + sorterResult[j] = iter.packedRecordPointer.getPartitionId(); + j += 1; + } + Assert.assertArrayEquals(numbersToSort, sorterResult); + } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java new file mode 100644 index 0000000000000..730d265c87f88 --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.ByteBuffer; +import java.util.*; + +import scala.*; +import scala.collection.Iterator; +import scala.reflect.ClassTag; +import scala.runtime.AbstractFunction1; + +import com.google.common.collect.HashMultiset; +import com.google.common.io.ByteStreams; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; +import static org.mockito.AdditionalAnswers.returnsFirstArg; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + +import org.apache.spark.*; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZ4CompressionCodec; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.io.SnappyCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.serializer.*; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +public class UnsafeShuffleWriterSuite { + + static final int NUM_PARTITITONS = 4; + final TaskMemoryManager taskMemoryManager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); + File mergedOutputFile; + File tempDir; + long[] partitionSizesInMergedFile; + final LinkedList spillFilesCreated = new LinkedList(); + SparkConf conf; + final Serializer serializer = new KryoSerializer(new SparkConf()); + TaskMetrics taskMetrics; + + @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; + @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; + @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; + @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; + + private final class CompressStream extends AbstractFunction1 { + @Override + public OutputStream apply(OutputStream stream) { + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); + } else { + return stream; + } + } + } + + @After + public void tearDown() { + Utils.deleteRecursively(tempDir); + final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (leakedMemory != 0) { + fail("Test leaked " + leakedMemory + " bytes of managed memory"); + } + } + + @Before + @SuppressWarnings("unchecked") + public void setUp() throws IOException { + MockitoAnnotations.initMocks(this); + tempDir = Utils.createTempDir("test", "test"); + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); + partitionSizesInMergedFile 
= null; + spillFilesCreated.clear(); + conf = new SparkConf(); + taskMetrics = new TaskMetrics(); + + when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); + + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); + when(blockManager.getDiskWriter( + any(BlockId.class), + any(File.class), + any(SerializerInstance.class), + anyInt(), + any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { + @Override + public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + Object[] args = invocationOnMock.getArguments(); + + return new DiskBlockObjectWriter( + (BlockId) args[0], + (File) args[1], + (SerializerInstance) args[2], + (Integer) args[3], + new CompressStream(), + false, + (ShuffleWriteMetrics) args[4] + ); + } + }); + when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer( + new Answer() { + @Override + public InputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + InputStream is = (InputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is); + } else { + return is; + } + } + } + ); + + when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer( + new Answer() { + @Override + public OutputStream answer(InvocationOnMock invocation) throws Throwable { + assert (invocation.getArguments()[0] instanceof TempShuffleBlockId); + OutputStream os = (OutputStream) invocation.getArguments()[1]; + if (conf.getBoolean("spark.shuffle.compress", true)) { + return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os); + } else { + return os; + } + } + } + ); + + when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + return null; + } + }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + + when(diskBlockManager.createTempShuffleBlock()).thenAnswer( + new Answer>() { + @Override + public Tuple2 answer( + InvocationOnMock invocationOnMock) throws Throwable { + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + } + }); + + when(taskContext.taskMetrics()).thenReturn(taskMetrics); + + when(shuffleDep.serializer()).thenReturn(Option.apply(serializer)); + when(shuffleDep.partitioner()).thenReturn(hashPartitioner); + } + + private UnsafeShuffleWriter createWriter( + boolean transferToEnabled) throws IOException { + conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); + return new UnsafeShuffleWriter( + blockManager, + shuffleBlockResolver, + taskMemoryManager, + shuffleMemoryManager, + new UnsafeShuffleHandle(0, 1, shuffleDep), + 0, // map id + taskContext, + conf + ); + } + + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + + private List> readRecordsFromFile() throws IOException { + final ArrayList> recordsList = new ArrayList>(); + long startOffset = 0; + for (int i = 0; 
i < NUM_PARTITITONS; i++) { + final long partitionSize = partitionSizesInMergedFile[i]; + if (partitionSize > 0) { + InputStream in = new FileInputStream(mergedOutputFile); + ByteStreams.skipFully(in, startOffset); + in = new LimitedInputStream(in, partitionSize); + if (conf.getBoolean("spark.shuffle.compress", true)) { + in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); + } + DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in); + Iterator> records = recordsStream.asKeyValueIterator(); + while (records.hasNext()) { + Tuple2 record = records.next(); + assertEquals(i, hashPartitioner.getPartition(record._1())); + recordsList.add(record); + } + recordsStream.close(); + startOffset += partitionSize; + } + } + return recordsList; + } + + @Test(expected=IllegalStateException.class) + public void mustCallWriteBeforeSuccessfulStop() throws IOException { + createWriter(false).stop(true); + } + + @Test + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { + createWriter(false).stop(false); + } + + @Test + public void writeEmptyIterator() throws Exception { + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(Collections.>emptyIterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + } + + @Test + public void writeWithoutSpilling() throws Exception { + // In this example, each partition should have exactly one record: + final ArrayList> dataToWrite = + new ArrayList>(); + for (int i = 0; i < NUM_PARTITITONS; i++) { + dataToWrite.add(new Tuple2(i, i)); + } + final UnsafeShuffleWriter writer = createWriter(true); + writer.write(dataToWrite.iterator()); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + // All partitions should be the same size: + assertEquals(partitionSizesInMergedFile[0], size); + sumOfPartitionSizes += size; + } + assertEquals(mergedOutputFile.length(), sumOfPartitionSizes); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertEquals(0, taskMetrics.diskBytesSpilled()); + assertEquals(0, taskMetrics.memoryBytesSpilled()); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + private void testMergingSpills( + boolean transferToEnabled, + String compressionCodecName) throws IOException { + if (compressionCodecName != null) { + conf.set("spark.shuffle.compress", "true"); + conf.set("spark.io.compression.codec", compressionCodecName); + } else { + conf.set("spark.shuffle.compress", "false"); + } + final UnsafeShuffleWriter writer = createWriter(transferToEnabled); + final ArrayList> dataToWrite = + new ArrayList>(); + for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { + dataToWrite.add(new Tuple2(i, i)); + } + 
writer.insertRecordIntoSorter(dataToWrite.get(0)); + writer.insertRecordIntoSorter(dataToWrite.get(1)); + writer.insertRecordIntoSorter(dataToWrite.get(2)); + writer.insertRecordIntoSorter(dataToWrite.get(3)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(dataToWrite.get(4)); + writer.insertRecordIntoSorter(dataToWrite.get(5)); + writer.closeAndWriteOutput(); + final Option mapStatus = writer.stop(true); + assertTrue(mapStatus.isDefined()); + assertTrue(mergedOutputFile.exists()); + assertEquals(2, spillFilesCreated.size()); + + long sumOfPartitionSizes = 0; + for (long size: partitionSizesInMergedFile) { + sumOfPartitionSizes += size; + } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); + + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void mergeSpillsWithTransferToAndLZF() throws Exception { + testMergingSpills(true, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZF() throws Exception { + testMergingSpills(false, LZFCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndLZ4() throws Exception { + testMergingSpills(true, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndLZ4() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndSnappy() throws Exception { + testMergingSpills(true, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithFileStreamAndSnappy() throws Exception { + testMergingSpills(false, SnappyCompressionCodec.class.getName()); + } + + @Test + public void mergeSpillsWithTransferToAndNoCompression() throws Exception { + testMergingSpills(true, null); + } + + @Test + public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { + testMergingSpills(false, null); + } + + @Test + public void writeEnoughDataToTriggerSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to allocate new data page + .then(returnsFirstArg()); // Grant new sort buffer and data page. 
+ final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; + for (int i = 0; i < 128 + 1; i++) { + dataToWrite.add(new Tuple2(i, bigByteArray)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + when(shuffleMemoryManager.tryToAcquire(anyLong())) + .then(returnsFirstArg()) // Allocate initial sort buffer + .then(returnsFirstArg()) // Allocate initial data page + .thenReturn(0L) // Deny request to grow sort buffer + .then(returnsFirstArg()); // Grant new sort buffer and data page. + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = new ArrayList>(); + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + dataToWrite.add(new Tuple2(i, i)); + } + writer.write(dataToWrite.iterator()); + verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + assertEquals(2, spillFilesCreated.size()); + writer.stop(true); + readRecordsFromFile(); + assertSpillFilesWereCleanedUp(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten()); + assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); + assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); + assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L)); + assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten()); + } + + @Test + public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception { + final UnsafeShuffleWriter writer = createWriter(false); + final ArrayList> dataToWrite = + new ArrayList>(); + final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; + new Random(42).nextBytes(bytes); + dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); + writer.write(dataToWrite.iterator()); + writer.stop(true); + assertEquals( + HashMultiset.create(dataToWrite), + HashMultiset.create(readRecordsFromFile())); + assertSpillFilesWereCleanedUp(); + } + + @Test + public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { + // Use a custom serializer so that we have exact control over the size of serialized data. 
+ final Serializer byteArraySerializer = new Serializer() { + @Override + public SerializerInstance newInstance() { + return new SerializerInstance() { + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + byte[] bytes = (byte[]) t; + try { + s.write(bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override + public void close() { } + }; + } + public ByteBuffer serialize(T t, ClassTag ev1) { return null; } + public DeserializationStream deserializeStream(InputStream s) { return null; } + public T deserialize(ByteBuffer b, ClassLoader l, ClassTag ev1) { return null; } + public T deserialize(ByteBuffer bytes, ClassTag ev1) { return null; } + }; + } + }; + when(shuffleDep.serializer()).thenReturn(Option.apply(byteArraySerializer)); + final UnsafeShuffleWriter writer = createWriter(false); + // Insert a record and force a spill so that there's something to clean up: + writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); + writer.forceSorterToSpill(); + // We should be able to write a record that's right _at_ the max record size + final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + new Random(42).nextBytes(atMaxRecordSize); + writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); + writer.forceSorterToSpill(); + // Inserting a record that's larger than the max record size should fail: + final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + new Random(42).nextBytes(exceedsMaxRecordSize); + Product2 hugeRecord = + new Tuple2(new byte[0], exceedsMaxRecordSize); + try { + // Here, we write through the public `write()` interface instead of the test-only + // `insertRecordIntoSorter` interface: + writer.write(Collections.singletonList(hugeRecord).iterator()); + fail("Expected exception to be thrown"); + } catch (IOException e) { + // Pass + } + assertSpillFilesWereCleanedUp(); + } + + @Test + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + final UnsafeShuffleWriter writer = createWriter(false); + writer.insertRecordIntoSorter(new Tuple2(1, 1)); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.forceSorterToSpill(); + writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.stop(false); + assertSpillFilesWereCleanedUp(); + } +} diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 8c6035fb367fe..cf6a143537889 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import com.google.common.io.ByteStreams import org.scalatest.FunSuite import org.apache.spark.SparkConf @@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) + assert(codec.getClass === classOf[LZ4CompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, 
classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) @@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lzf supports concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) + assert(codec.getClass === classOf[LZFCompressionCodec]) + testConcatenationOfSerializedStreams(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) @@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("snappy does not support concatenation of serialized streams") { + val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + intercept[Exception] { + testConcatenationOfSerializedStreams(codec) + } + } + test("bad compression codec") { intercept[IllegalArgumentException] { CompressionCodec.createCodec(conf, "foobar") } } + + private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = { + val bytes1: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (0 to 64).foreach(out.write) + out.close() + baos.toByteArray + } + val bytes2: Array[Byte] = { + val baos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(baos) + (65 to 127).foreach(out.write) + out.close() + baos.toByteArray + } + val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2)) + val decompressed: Array[Byte] = new Array[Byte](128) + ByteStreams.readFully(concatenatedBytes, decompressed) + assert(decompressed.toSeq === (0 to 127)) + } } diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala new file mode 100644 index 0000000000000..ed4d8ce632e16 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import org.apache.spark.SparkConf +import org.scalatest.FunSuite + +class JavaSerializerSuite extends FunSuite { + test("JavaSerializer instances are serializable") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(serializer)) + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala new file mode 100644 index 0000000000000..49a04a2a45280 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark._ +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer} + +/** + * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are + * performed in other suites. 
+ */ +class UnsafeShuffleManagerSuite extends FunSuite with Matchers { + + import UnsafeShuffleManager.canUseUnsafeShuffle + + private class RuntimeExceptionAnswer extends Answer[Object] { + override def answer(invocation: InvocationOnMock): Object = { + throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName) + } + } + + private def shuffleDep( + partitioner: Partitioner, + serializer: Option[Serializer], + keyOrdering: Option[Ordering[Any]], + aggregator: Option[Aggregator[Any, Any, Any]], + mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = { + val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer()) + doReturn(0).when(dep).shuffleId + doReturn(partitioner).when(dep).partitioner + doReturn(serializer).when(dep).serializer + doReturn(keyOrdering).when(dep).keyOrdering + doReturn(aggregator).when(dep).aggregator + doReturn(mapSideCombine).when(dep).mapSideCombine + dep + } + + test("supported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]]) + when(rangePartitioner.numPartitions).thenReturn(2) + assert(canUseUnsafeShuffle(shuffleDep( + partitioner = rangePartitioner, + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + } + + test("unsupported shuffle dependencies") { + val kryo = Some(new KryoSerializer(new SparkConf())) + val java = Some(new JavaSerializer(new SparkConf())) + + // We only support serializers that support object relocation + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = java, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles with more than 16 million output partitions + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1), + serializer = kryo, + keyOrdering = None, + aggregator = None, + mapSideCombine = false + ))) + + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = None, + mapSideCombine = false + ))) + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = None, + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = false + ))) + // We do not support shuffles that perform any kind of aggregation or sorting of keys + assert(!canUseUnsafeShuffle(shuffleDep( + partitioner = new HashPartitioner(2), + serializer = kryo, + keyOrdering = Some(mock(classOf[Ordering[Any]])), + aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])), + mapSideCombine = true + ))) + } + +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala new file mode 100644 index 0000000000000..6351539e91e97 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.util.Utils + +class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle. + + override def beforeAll() { + conf.set("spark.shuffle.manager", "tungsten-sort") + // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort + // shuffle records. + conf.set("spark.shuffle.memoryFraction", "0.5") + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4)) + .setSerializer(new KryoSerializer(myConf)) + val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep)) + def getAllFiles: Set[File] = + FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + val filesBeforeShuffle = getAllFiles + // Force the shuffle to be performed + shuffledRdd.count() + // Ensure that the shuffle actually created files that will need to be cleaned up + val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle + filesCreatedByShuffle.map(_.getName) should be + Set("shuffle_0_0_0.data", "shuffle_0_0_0.index") + // Check that the cleanup actually removes the files + sc.env.blockManager.master.removeShuffle(0, blocking = true) + for (file <- filesCreatedByShuffle) { + assert (!file.exists(), s"Shuffle file $file was not cleaned up") + } + } finally { + Utils.deleteRecursively(tmpDir) + } + } + + test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") { + val tmpDir = Utils.createTempDir() + try { + val myConf = conf.clone() + .set("spark.local.dir", tmpDir.getAbsolutePath) + sc = new SparkContext("local", "test", myConf) + // Create a shuffled RDD and verify that it will actually use the old SortShuffle path + val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x)) + val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, 
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      assert(filesCreatedByShuffle.map(_.getName) ===
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+}
diff --git a/pom.xml b/pom.xml
index cf9279ea5a2a6..564a443466e5a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
       <dependency>
         <groupId>org.mockito</groupId>
         <artifactId>mockito-all</artifactId>
-        <version>1.9.0</version>
+        <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
@@ -684,6 +684,18 @@
         <version>4.10</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-core</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-library</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
       <dependency>
         <groupId>com.novocode</groupId>
         <artifactId>junit-interface</artifactId>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fba7290dcb0b5..487062a31f77f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -131,6 +131,12 @@ object MimaExcludes {
             // SPARK-7530 Added StreamingContext.getState()
             ProblemFilters.exclude[MissingMethodProblem](
              "org.apache.spark.streaming.StreamingContext.state_=")
+          ) ++ Seq(
+            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+            // unnecessary type bounds in order to fix some compiler warnings that occurred when
+            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+              "org.apache.spark.shuffle.ShuffleWriter")
           )
 
         case v if v.startsWith("1.3") =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c7019a54a..3e46596ecf6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
     // corner-cases where a partitioner constructed with `numPartitions` partitions may output
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
-    val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+      shuffleManager.isInstanceOf[UnsafeShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
     val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
-      // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
-      // However, there are two special cases where we can avoid the copy, described below:
-      if (partitioner.numPartitions <= bypassMergeThreshold) {
-        // If the number of output partitions is sufficiently small, then Spark will fall back to
-        // the old hash-based shuffle write path which doesn't buffer deserialized records.
+      val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+      if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+        // doesn't buffer deserialized records.
        // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
        false
      } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
         // them. This optimization is guarded by a feature-flag and is only applied in cases where
         // shuffle dependency does not specify an ordering and the record serializer has certain
         // properties. If this optimization is enabled, we can safely avoid the copy.
+        //
+        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // None of the special cases held, so we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+        // both cases, we must copy.
         true
       }
     } else {
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b0733206b2bc..9e151fc7a9141 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988e6ad69..2906ac8abad1a 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@
 
 import java.util.*;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
 
   private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
-  /**
-   * The number of entries in the page table.
-   */
-  private static final int PAGE_TABLE_SIZE = 1 << 13;
+  /** The number of bits used to address the page table. */
+  private static final int PAGE_NUMBER_BITS = 13;
+
+  /** The number of bits used to encode offsets in data pages. */
+  @VisibleForTesting
+  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
+
+  /** The number of entries in the page table. */
+  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+  /** Maximum supported data page size */
+  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
    * intended for allocating large blocks of memory that will be shared between operators.
    */
   public MemoryBlock allocatePage(long size) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Allocating {} byte page", size);
-    }
-    if (size >= (1L << 51)) {
-      throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+    if (size > MAXIMUM_PAGE_SIZE) {
+      throw new IllegalArgumentException(
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
     }
 
     final int pageNumber;
@@ -120,8 +127,8 @@ public MemoryBlock allocatePage(long size) {
     final MemoryBlock page = executorMemoryManager.allocate(size);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+    if (logger.isTraceEnabled()) {
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
     }
     return page;
   }
@@ -130,9 +137,6 @@ public MemoryBlock allocatePage(long size) {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
    */
   public void freePage(MemoryBlock page) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
-    }
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public void freePage(MemoryBlock page) {
       allocatedPages.clear(page.pageNumber);
     }
     pageTable[page.pageNumber] = null;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+    if (logger.isTraceEnabled()) {
+      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
   }
 
@@ -173,14 +177,36 @@ public void free(MemoryBlock memory) {
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
+   *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}.
+   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+   *                     this should be the value that you would pass as the base offset into an
+   *                     UNSAFE call (e.g. page.baseOffset() + something).
+   * @return an encoded page address.
    */
   public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
-    if (inHeap) {
-      assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
-      return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
-    } else {
-      return offsetInPage;
+    if (!inHeap) {
+      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+      // encode. Due to our page size limitation, though, we can convert this into an offset that's
+      // relative to the page's base offset; this relative offset will fit in 51 bits.
+      offsetInPage -= page.getBaseOffset();
     }
+    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+  }
+
+  @VisibleForTesting
+  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+  }
+
+  @VisibleForTesting
+  public static int decodePageNumber(long pagePlusOffsetAddress) {
+    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+  }
+
+  private static long decodeOffset(long pagePlusOffsetAddress) {
+    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+  }
 
   /**
@@ -189,7 +215,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
    */
   public Object getPage(long pagePlusOffsetAddress) {
     if (inHeap) {
-      final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
       final Object page = pageTable[pageNumber].getBaseObject();
       assert (page != null);
@@ -204,10 +230,15 @@ public Object getPage(long pagePlusOffsetAddress) {
    * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public long getOffsetInPage(long pagePlusOffsetAddress) {
+    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
     if (inHeap) {
-      return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+      return offsetInPage;
     } else {
-      return pagePlusOffsetAddress;
+      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+      // converted the absolute address into a relative address. Here, we invert that operation:
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
     }
   }
 
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f1ca248..06fb081183659 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public void leakedPageMemoryIsDetected() {
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner-case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
 }
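
As a quick illustration of the 13-bit page number / 51-bit offset packing that the TaskMemoryManager changes above rely on, the following standalone sketch mirrors the encode/decode arithmetic (PAGE_NUMBER_BITS, OFFSET_BITS, and the lower-51-bit mask). It is not part of the patch; the class name and the sample page number and offset are chosen only for the example.

    public final class PageAddressExample {

      // These constants mirror the ones added to TaskMemoryManager in this patch.
      private static final int PAGE_NUMBER_BITS = 13;
      private static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;                // 51
      private static final long MASK_LONG_LOWER_51_BITS = (1L << OFFSET_BITS) - 1; // 0x7FFFFFFFFFFFFL

      static long encode(int pageNumber, long offsetInPage) {
        // Page number goes in the upper 13 bits, offset in the lower 51 bits.
        return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
      }

      static int decodePageNumber(long address) {
        // Unsigned shift recovers the upper 13 bits.
        return (int) (address >>> OFFSET_BITS);
      }

      static long decodeOffset(long address) {
        return address & MASK_LONG_LOWER_51_BITS;
      }

      public static void main(String[] args) {
        // Hypothetical page number and offset, chosen only for the example.
        final long encoded = encode(5, 64);
        System.out.println(Long.toHexString(encoded));   // 28000000000040
        System.out.println(decodePageNumber(encoded));   // 5
        System.out.println(decodeOffset(encoded));       // 64
      }
    }

With 51 bits available for the offset, the largest encodable page is 2^51 bytes, which is exactly the MAXIMUM_PAGE_SIZE guard added to allocatePage() above.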