From 07484589af43de03343fe58601793f6aaff33d56 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 1 May 2015 17:50:48 -0700
Subject: [PATCH] Port UnsafeShuffleWriter to Java.

---
 .../shuffle/unsafe/UnsafeShuffleWriter.java   | 282 ++++++++++++++++++
 .../shuffle/unsafe/UnsafeShuffleManager.scala | 243 ---------------
 2 files changed, 282 insertions(+), 243 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000000000..bf368d4a11526
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import scala.Option;
+import scala.Product2;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+import java.util.LinkedList;
+
+import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.BlockObjectWriter;
+import org.apache.spark.storage.ShuffleBlockId;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.sort.UnsafeSorter;
+import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
+
+// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
+public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
+
+  private static final int PAGE_SIZE = 1024 * 1024;  // TODO: tune this
+  private static final int SER_BUFFER_SIZE = 1024 * 1024;  // TODO: tune this
+  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+  private final IndexShuffleBlockManager shuffleBlockManager;
+  private final BlockManager blockManager = SparkEnv.get().blockManager();
+  private final int shuffleId;
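+  // ID of the map task that owns this writer; combined with shuffleId to locate this task's
+  // shuffle data and index files.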
+  private final int mapId;
+  private final TaskMemoryManager memoryManager;
+  private final SerializerInstance serializer;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+  private final int fileBufferSize;
+  private MapStatus mapStatus = null;
+
+  private MemoryBlock currentPage = null;
+  private long currentPagePosition = PAGE_SIZE;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
+
+  public UnsafeShuffleWriter(
+      IndexShuffleBlockManager shuffleBlockManager,
+      UnsafeShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext context) {
+    this.shuffleBlockManager = shuffleBlockManager;
+    this.mapId = mapId;
+    this.memoryManager = context.taskMemoryManager();
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.shuffleId = dep.shuffleId();
+    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.fileBufferSize =
+      // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
+      (int) SparkEnv.get().conf().getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+  }
+
+  public void write(scala.collection.Iterator<Product2<K, V>> records) {
+    try {
+      final long[] partitionLengths = writeSortedRecordsToFile(sortRecords(records));
+      shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    } catch (Exception e) {
+      PlatformDependent.throwException(e);
+    }
+  }
+
+  private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
+    if (requiredSpace > PAGE_SIZE) {
+      // TODO: throw a more specific exception?
+      throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
+        PAGE_SIZE + ")");
+    } else if (requiredSpace > (PAGE_SIZE - currentPagePosition)) {
+      currentPage = memoryManager.allocatePage(PAGE_SIZE);
+      currentPagePosition = currentPage.getBaseOffset();
+      allocatedPages.add(currentPage);
+    }
+  }
+
+  private void freeMemory() {
+    final Iterator<MemoryBlock> iter = allocatedPages.iterator();
+    while (iter.hasNext()) {
+      memoryManager.freePage(iter.next());
+      iter.remove();
+    }
+  }
+
+  private Iterator<RecordPointerAndKeyPrefix> sortRecords(
+      scala.collection.Iterator<Product2<K, V>> records) throws Exception {
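+    // Serialize each incoming record into the current memory page and insert a pointer to it
+    // into the sorter; the 8-byte sort prefix is the record's partition id, so the sorted
+    // iterator returns records grouped by partition.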
+    final UnsafeSorter sorter = new UnsafeSorter(
+      memoryManager,
+      RECORD_COMPARATOR,
+      PREFIX_COMPUTER,
+      PREFIX_COMPARATOR,
+      4096 // Initial size (TODO: tune this!)
+    );
+
+    final byte[] serArray = new byte[SER_BUFFER_SIZE];
+    final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray);
+    // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
+    final SerializationStream serOutputStream =
+      serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+
+    while (records.hasNext()) {
+      final Product2<K, V> record = records.next();
+      final K key = record._1();
+      final int partitionId = partitioner.getPartition(key);
+      serByteBuffer.position(0);
+      serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+      serOutputStream.flush();
+
+      final int serializedRecordSize = serByteBuffer.position();
+      assert (serializedRecordSize > 0);
+      // TODO: we should run the partition extraction function _now_, at insert time, rather than
+      // requiring it to be stored alongside the data, since this may lead to double storage
+      // Need 8 bytes to store the prefix (for later retrieval in the prefix computer), plus
+      // 8 to store the record length (TODO: can store as an int instead).
+      ensureSpaceInDataPage(serializedRecordSize + 8 + 8);
+
+      final long recordAddress =
+        memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+      final Object baseObject = currentPage.getBaseObject();
+      PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, partitionId);
+      currentPagePosition += 8;
+      PlatformDependent.UNSAFE.putLong(baseObject, currentPagePosition, serializedRecordSize);
+      currentPagePosition += 8;
+      PlatformDependent.copyMemory(
+        serArray,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        baseObject,
+        currentPagePosition,
+        serializedRecordSize);
+      currentPagePosition += serializedRecordSize;
+
+      sorter.insertRecord(recordAddress);
+    }
+
+    return sorter.getSortedIterator();
+  }
+
+  private long[] writeSortedRecordsToFile(
+      Iterator<RecordPointerAndKeyPrefix> sortedRecords) throws IOException {
+    final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
+    final ShuffleBlockId blockId =
+      new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
+    final long[] partitionLengths = new long[partitioner.numPartitions()];
+
+    int currentPartition = -1;
+    BlockObjectWriter writer = null;
+
+    while (sortedRecords.hasNext()) {
+      final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
+      final int partition = (int) recordPointer.keyPrefix;
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition
+        if (currentPartition != -1) {
+          writer.commitAndClose();
+          partitionLengths[currentPartition] = writer.fileSegment().length();
+        }
+        currentPartition = partition;
+        writer =
+          blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics);
+      }
+
+      final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
+      final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
+      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8);
+      // TODO: re-use a buffer or avoid double-buffering entirely
+      final byte[] arr = new byte[recordLength];
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + 16,
+        arr,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        recordLength);
+      assert (writer != null);  // To suppress an IntelliJ warning
+      writer.write(arr);
+      // TODO: add a test that detects whether we leave this call out:
+      writer.recordWritten();
+    }
+
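+    // If any records were written, the writer for the final partition is still open here;
+    // commit it and record its segment length as well.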
+    if (writer != null) {
+      writer.commitAndClose();
+      partitionLengths[currentPartition] = writer.fileSegment().length();
+    }
+
+    return partitionLengths;
+  }
+
+  @Override
+  public Option<MapStatus> stop(boolean success) {
+    try {
+      if (stopping) {
+        return Option.apply(null);
+      } else {
+        stopping = true;
+        freeMemory();
+        if (success) {
+          return Option.apply(mapStatus);
+        } else {
+          // The map task failed, so delete our output data.
+          shuffleBlockManager.removeDataByMap(shuffleId, mapId);
+          return Option.apply(null);
+        }
+      }
+    } finally {
+      freeMemory();
+      // TODO: increment the shuffle write time metrics
+    }
+  }
+
+  private static final RecordComparator RECORD_COMPARATOR = new RecordComparator() {
+    @Override
+    public int compare(
+        Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) {
+      return 0;
+    }
+  };
+
+  private static final PrefixComputer PREFIX_COMPUTER = new PrefixComputer() {
+    @Override
+    public long computePrefix(Object baseObject, long baseOffset) {
+      // TODO: should the prefix be computed when inserting the record pointer rather than being
+      // read from the record itself? May be more efficient in terms of space, etc, and is a simple
+      // change.
+      return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
+    }
+  };
+
+  private static final PrefixComparator PREFIX_COMPARATOR = new PrefixComparator() {
+    @Override
+    public int compare(long prefix1, long prefix2) {
+      return (int) (prefix1 - prefix2);
+    }
+  };
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index 7eb825641263b..0dd34b372f624 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -17,22 +17,10 @@
 package org.apache.spark.shuffle.unsafe
 
-import java.nio.ByteBuffer
-import java.util
-
-import com.esotericsoftware.kryo.io.ByteBufferOutputStream
-
 import org.apache.spark._
-import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId}
-import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
-import org.apache.spark.unsafe.sort.UnsafeSorter
-import org.apache.spark.unsafe.sort.UnsafeSorter.{RecordPointerAndKeyPrefix, PrefixComparator, PrefixComputer, RecordComparator}
 
 private class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
@@ -62,237 +50,6 @@ private[spark] object UnsafeShuffleManager extends Logging {
   }
 }
 
-private object DummyRecordComparator extends RecordComparator {
-  override def compare(
-      leftBaseObject: scala.Any,
-      leftBaseOffset: Long,
-      rightBaseObject: scala.Any,
-      rightBaseOffset: Long): Int = {
-    0
-  }
-}
-
-private object PartitionerPrefixComputer extends PrefixComputer {
-  override def computePrefix(baseObject: scala.Any, baseOffset: Long): Long = {
-    // TODO: should the prefix be computed when inserting the record pointer rather than being
-    // read from the record itself? May be more efficient in terms of space, etc, and is a simple
-    // change.
-    PlatformDependent.UNSAFE.getLong(baseObject, baseOffset)
-  }
-}
-
-private object PartitionerPrefixComparator extends PrefixComparator {
-  override def compare(prefix1: Long, prefix2: Long): Int = {
-    (prefix1 - prefix2).toInt
-  }
-}
-
-private class UnsafeShuffleWriter[K, V](
-    shuffleBlockManager: IndexShuffleBlockManager,
-    handle: UnsafeShuffleHandle[K, V],
-    mapId: Int,
-    context: TaskContext)
-  extends ShuffleWriter[K, V] {
-
-  private[this] val memoryManager: TaskMemoryManager = context.taskMemoryManager()
-
-  private[this] val dep = handle.dependency
-
-  private[this] val partitioner = dep.partitioner
-
-  // Are we in the process of stopping? Because map tasks can call stop() with success = true
-  // and then call stop() with success = false if they get an exception, we want to make sure
-  // we don't try deleting files, etc twice.
-  private[this] var stopping = false
-
-  private[this] var mapStatus: MapStatus = null
-
-  private[this] val writeMetrics = new ShuffleWriteMetrics()
-  context.taskMetrics().shuffleWriteMetrics = Some(writeMetrics)
-
-  private[this] val allocatedPages: util.LinkedList[MemoryBlock] =
-    new util.LinkedList[MemoryBlock]()
-
-  private[this] val blockManager = SparkEnv.get.blockManager
-
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
-  private[this] val fileBufferSize =
-    SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
-
-  private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance()
-
-  private def sortRecords(
-      records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[RecordPointerAndKeyPrefix] = {
-    val sorter = new UnsafeSorter(
-      context.taskMemoryManager(),
-      DummyRecordComparator,
-      PartitionerPrefixComputer,
-      PartitionerPrefixComparator,
-      4096 // initial size
-    )
-    val PAGE_SIZE = 1024 * 1024 * 1
-
-    var currentPage: MemoryBlock = null
-    var currentPagePosition: Long = PAGE_SIZE
-
-    def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
-      if (spaceRequired > PAGE_SIZE) {
-        throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE")
-      } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) {
-        currentPage = memoryManager.allocatePage(PAGE_SIZE)
-        allocatedPages.add(currentPage)
-        currentPagePosition = currentPage.getBaseOffset
-      }
-    }
-
-    // TODO: the size of this buffer should be configurable
-    val serArray = new Array[Byte](1024 * 1024)
-    val byteBuffer = ByteBuffer.wrap(serArray)
-    val bbos = new ByteBufferOutputStream()
-    bbos.setByteBuffer(byteBuffer)
-    val serBufferSerStream = serializer.serializeStream(bbos)
-
-    def writeRecord(record: Product2[Any, Any]): Unit = {
-      val (key, value) = record
-      val partitionId = partitioner.getPartition(key)
-      serBufferSerStream.writeKey(key)
-      serBufferSerStream.writeValue(value)
-      serBufferSerStream.flush()
-
-      val serializedRecordSize = byteBuffer.position()
-      assert(serializedRecordSize > 0)
-      // TODO: we should run the partition extraction function _now_, at insert time, rather than
-      // requiring it to be stored alongisde the data, since this may lead to double storage
-      val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
-      ensureSpaceInDataPage(sizeRequirementInSortDataPage)
-
-      val newRecordAddress =
-        memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
-      PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
-      currentPagePosition += 8
-      PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
-      currentPagePosition += 8
-      PlatformDependent.copyMemory(
-        serArray,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        currentPage.getBaseObject,
-        currentPagePosition,
-        serializedRecordSize)
-      currentPagePosition += serializedRecordSize
-      sorter.insertRecord(newRecordAddress)
-
-      // Reset for writing the next record
-      byteBuffer.position(0)
-    }
-
-    while (records.hasNext) {
-      writeRecord(records.next())
-    }
-
-    sorter.getSortedIterator
-  }
-
-  private def writeSortedRecordsToFile(
-      sortedRecords: java.util.Iterator[RecordPointerAndKeyPrefix]): Array[Long] = {
-    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
-    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
-    val partitionLengths = new Array[Long](partitioner.numPartitions)
-
-    var currentPartition = -1
-    var writer: BlockObjectWriter = null
-
-    // TODO: don't close and re-open file handles so often; this could be inefficient
-
-    def closePartition(): Unit = {
-      if (writer != null) {
-        writer.commitAndClose()
-        partitionLengths(currentPartition) = writer.fileSegment().length
-      }
-    }
-
-    def switchToPartition(newPartition: Int): Unit = {
-      assert (newPartition > currentPartition,
-        s"new partition $newPartition should be >= $currentPartition")
-      if (currentPartition != -1) {
-        closePartition()
-      }
-      currentPartition = newPartition
-      writer =
-        blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
-    }
-
-    while (sortedRecords.hasNext) {
-      val keyPointerAndPrefix: RecordPointerAndKeyPrefix = sortedRecords.next()
-      val partition = keyPointerAndPrefix.keyPrefix.toInt
-      if (partition != currentPartition) {
-        switchToPartition(partition)
-      }
-      val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
-      val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
-      val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
-      // TODO: need to have a way to figure out whether a serializer supports relocation of
-      // serialized objects or not. Sandy also ran into this in his patch (see
-      // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
-      // as well just bypass this optimized code path in favor of the old one.
-      // TODO: re-use a buffer or avoid double-buffering entirely
-      val arr: Array[Byte] = new Array[Byte](recordLength)
-      PlatformDependent.copyMemory(
-        baseObject,
-        baseOffset + 16,
-        arr,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        recordLength)
-      writer.write(arr)
-      // TODO: add a test that detects whether we leave this call out:
-      writer.recordWritten()
-    }
-    closePartition()
-
-    partitionLengths
-  }
-
-  /** Write a sequence of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    val sortedIterator = sortRecords(records)
-    val partitionLengths = writeSortedRecordsToFile(sortedIterator)
-    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
-  }
-
-  private def freeMemory(): Unit = {
-    val iter = allocatedPages.iterator()
-    while (iter.hasNext) {
-      memoryManager.freePage(iter.next())
-      iter.remove()
-    }
-  }
-
-  /** Close this writer, passing along whether the map completed */
-  override def stop(success: Boolean): Option[MapStatus] = {
-    try {
-      if (stopping) {
-        None
-      } else {
-        stopping = true
-        freeMemory()
-        if (success) {
-          Option(mapStatus)
-        } else {
-          // The map task failed, so delete our output data.
-          shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId)
-          None
-        }
-      }
-    } finally {
-      freeMemory()
-      val startTime = System.nanoTime()
-      context.taskMetrics().shuffleWriteMetrics.foreach(
-        _.incShuffleWriteTime(System.nanoTime - startTime))
-    }
-  }
-}
-
 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {
 
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
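
For readers who want to see the in-page record format outside of the patch, the following standalone sketch (not part of the commit; the class name RecordLayoutSketch is made up for illustration) mirrors the [8-byte partition id][8-byte length][serialized bytes] layout that sortRecords() writes and writeSortedRecordsToFile() reads back, with a plain java.nio.ByteBuffer standing in for the TaskMemoryManager pages:

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public class RecordLayoutSketch {
      public static void main(String[] args) {
        final byte[] serializedRecord = "key=1,value=a".getBytes(StandardCharsets.UTF_8);
        final int partitionId = 3;

        // Write side, mirroring sortRecords(): an 8-byte partition id (also the sort prefix),
        // an 8-byte record length, then the serialized key/value bytes.
        final ByteBuffer page = ByteBuffer.allocate(1024);
        page.putLong(partitionId);
        page.putLong(serializedRecord.length);
        page.put(serializedRecord);

        // Read side, mirroring writeSortedRecordsToFile(): the prefix sits at offset 0, the
        // record length at offset 8, and the payload starts at offset 16.
        final long prefix = page.getLong(0);
        final int length = (int) page.getLong(8);
        final byte[] copy = new byte[length];
        page.position(16);
        page.get(copy);
        System.out.println(prefix + " -> " + new String(copy, StandardCharsets.UTF_8));
      }
    }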