From 8e3ec208be6cdf4eb167bdbe6940ef5552aeb58a Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 1 May 2015 13:31:28 -0700
Subject: [PATCH] Begin code cleanup.

---
 .../shuffle/unsafe/UnsafeShuffleManager.scala | 149 ++++++++++--------
 1 file changed, 84 insertions(+), 65 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index 56d76304e713a..0e796dfe2aefd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -17,23 +17,23 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{ByteArrayOutputStream, FileOutputStream}
+import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.unsafe.PlatformDependent
-import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock}
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
 import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}
-import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext}
-import org.apache.spark.shuffle._
 
 private[spark] class UnsafeShuffleHandle[K, V](
     shuffleId: Int,
@@ -87,7 +87,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val dep = handle.dependency
 
-  private[this] var sorter: UnsafeSorter = null
+  private[this] val partitioner = dep.partitioner
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -104,52 +104,55 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  /** Write a sequence of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
-    println("Opened writer!")
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
-    val partitioner = dep.partitioner
-    sorter = new UnsafeSorter(
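+  /**
+   * Serializes each incoming record into an in-memory data page and inserts a pointer to
+   * it, prefixed by its partition id, into the sorter. Returns an iterator over the
+   * sorted record pointers.
+   */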
+  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+    val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-
-    // Pack records into data pages:
+    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
+
     var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    allocatedPages.add(currentPage)
     var currentPagePosition: Long = currentPage.getBaseOffset
 
-    // TODO make this configurable
+    def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
+      if (spaceRequired > PAGE_SIZE) {
+        throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE")
+      } else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) {
+        currentPage = memoryManager.allocatePage(PAGE_SIZE)
+        allocatedPages.add(currentPage)
+        currentPagePosition = currentPage.getBaseOffset
+      }
+    }
+
+    // TODO: the size of this buffer should be configurable
     val serArray = new Array[Byte](1024 * 1024)
     val byteBuffer = ByteBuffer.wrap(serArray)
     val bbos = new ByteBufferOutputStream()
     bbos.setByteBuffer(byteBuffer)
     val serBufferSerStream = serializer.serializeStream(bbos)
 
-    while (records.hasNext) {
-      val nextRecord: Product2[K, V] = records.next()
-      println("Writing record " + nextRecord)
-      val partitionId: Int = partitioner.getPartition(nextRecord._1)
-      serBufferSerStream.writeObject(nextRecord)
-
-      val sizeRequirement: Int = byteBuffer.position() + 8 + 8
-      println("Size requirement in intenral buffer is " + sizeRequirement)
-      if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) {
-        println("Allocating a new data page after writing " + currentPagePosition)
-        currentPage = memoryManager.allocatePage(PAGE_SIZE)
-        allocatedPages.add(currentPage)
-        currentPagePosition = currentPage.getBaseOffset
-      }
-      println("Before writing record, current page position is " + currentPagePosition)
-      // TODO: check that it's still not too large
+    def writeRecord(record: Product2[Any, Any]): Unit = {
+      val (key, value) = record
+      val partitionId = partitioner.getPartition(key)
+      serBufferSerStream.writeKey(key)
+      serBufferSerStream.writeValue(value)
+      serBufferSerStream.flush()
+
+      val serializedRecordSize = byteBuffer.position()
+      // TODO: we should run the partition extraction function _now_, at insert time, rather than
+      // requiring it to be stored alongside the data, since this may lead to double storage
+      val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
+      ensureSpaceInDataPage(sizeRequirementInSortDataPage)
+
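+      // Layout of each record in the data page:
+      //   [partition id: 8 bytes][serialized length: 8 bytes][serialized record bytes]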
       val newRecordAddress = memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
+      println("The stored record length is " + byteBuffer.position())
       PlatformDependent.UNSAFE.putLong(
         currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
       currentPagePosition += 8
@@ -162,45 +165,53 @@ private[spark] class UnsafeShuffleWriter[K, V](
       currentPagePosition += byteBuffer.position()
       println("After writing record, current page position is " + currentPagePosition)
       sorter.insertRecord(newRecordAddress)
+
+      // Reset for writing the next record
       byteBuffer.position(0)
     }
-    // TODO: free the buffers, etc, at this point since they're not needed
-    val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator
-    // Now that the partition is sorted, write out the data to a file, keeping track off offsets
-    // for use in the sort-based shuffle index.
+
+    while (records.hasNext) {
+      writeRecord(records.next())
+    }
+
+    sorter.getSortedIterator
+  }
+
+  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
-    // TODO: compression tests?
-    // TODO why is append true here?
-    // TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us
-    // TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were
-    // not properly wrapping the writer for compression even though readers expected compressed
-    // data; the fact that someone still reported this issue in newer Spark versions suggests that
-    // we should audit the code to make sure wrapping is done at the right set of places and to
-    // check that we haven't missed any rare corner-cases / rarely-used paths.
-    val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
-    val serOut = serializer.serializeStream(out)
-    serOut.flush()
+
     var currentPartition = -1
-    var currentPartitionLength: Long = 0
-    while (sortedIterator.hasNext) {
-      val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next()
-      val partition = keyPointerAndPrefix.keyPrefix.toInt
-      println("Partition is " + partition)
-      if (currentPartition == -1) {
-        currentPartition = partition
+    var prevPartitionLength: Long = 0
+    var out: OutputStream = null
+
+    // TODO: don't close and re-open file handles so often; this could be inefficient
+
+    def closePartition(): Unit = {
+      out.flush()
+      out.close()
+      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+    }
+
+    def switchToPartition(newPartition: Int): Unit = {
+      if (currentPartition != -1) {
+        closePartition()
+        prevPartitionLength = partitionLengths(currentPartition)
       }
+      currentPartition = newPartition
+      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+    }
+
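+    // Walk the sorted record pointers, copying each record's bytes from its data page
+    // into the compressed output stream for its partition.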
println("Record length is " + recordLength) var i: Int = 0 @@ -213,10 +224,19 @@ private[spark] class UnsafeShuffleWriter[K, V]( i += 1 } } - out.flush() - //serOut.close() - //out.flush() - out.close() + closePartition() + + partitionLengths + } + + /** Write a sequence of records to this task's output */ + override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + println("Opened writer!") + + val sortedIterator = sortRecords(records) + val partitionLengths = writeSortedRecordsToFile(sortedIterator) + + println("Partition lengths are " + partitionLengths.toSeq) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } @@ -239,7 +259,7 @@ private[spark] class UnsafeShuffleWriter[K, V]( } } finally { // Clean up our sorter, which may have its own intermediate files - if (sorter != null) { + if (!allocatedPages.isEmpty) { val iter = allocatedPages.iterator() while (iter.hasNext) { memoryManager.freePage(iter.next()) @@ -249,7 +269,6 @@ private[spark] class UnsafeShuffleWriter[K, V]( //sorter.stop() context.taskMetrics().shuffleWriteMetrics.foreach( _.incShuffleWriteTime(System.nanoTime - startTime)) - sorter = null } } }