From 253f13ee0796aa724decd075d159b81eda459daf Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 1 May 2015 13:55:35 -0700
Subject: [PATCH] More cleanup

---
 .../shuffle/unsafe/UnsafeShuffleManager.scala | 44 ++++++++++---------
 1 file changed, 23 insertions(+), 21 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
index 0e796dfe2aefd..04ac811ac7966 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -115,8 +115,8 @@ private[spark] class UnsafeShuffleWriter[K, V](
     val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
-    var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
-    var currentPagePosition: Long = currentPage.getBaseOffset
+    var currentPage: MemoryBlock = null
+    var currentPagePosition: Long = PAGE_SIZE
 
     def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
       if (spaceRequired > PAGE_SIZE) {
@@ -143,6 +143,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
       serBufferSerStream.flush()
 
       val serializedRecordSize = byteBuffer.position()
+      assert(serializedRecordSize > 0)
       // TODO: we should run the partition extraction function _now_, at insert time, rather than
       // requiring it to be stored alongisde the data, since this may lead to double storage
       val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
@@ -152,17 +153,17 @@ private[spark] class UnsafeShuffleWriter[K, V](
         memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
       PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
       currentPagePosition += 8
-      println("The stored record length is " + byteBuffer.position())
+      println("The stored record length is " + serializedRecordSize)
       PlatformDependent.UNSAFE.putLong(
-        currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
+        currentPage.getBaseObject, currentPagePosition, serializedRecordSize)
       currentPagePosition += 8
       PlatformDependent.copyMemory(
         serArray,
         PlatformDependent.BYTE_ARRAY_OFFSET,
         currentPage.getBaseObject,
         currentPagePosition,
-        byteBuffer.position())
-      currentPagePosition += byteBuffer.position()
+        serializedRecordSize)
+      currentPagePosition += serializedRecordSize
       println("After writing record, current page position is " + currentPagePosition)
 
       sorter.insertRecord(newRecordAddress)
@@ -195,10 +196,12 @@ private[spark] class UnsafeShuffleWriter[K, V](
     }
 
     def switchToPartition(newPartition: Int): Unit = {
+      assert(newPartition > currentPartition, s"new partition $newPartition should be > $currentPartition")
       if (currentPartition != -1) {
         closePartition()
         prevPartitionLength = partitionLengths(currentPartition)
       }
+      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
       out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
 
@@ -214,11 +217,11 @@ private[spark] class UnsafeShuffleWriter[K, V](
       val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
       println("Base offset is " + baseOffset)
       println("Record length is " + recordLength)
-      var i: Int = 0
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not.  Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450).  If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
+      var i: Int = 0
       while (i < recordLength) {
         out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
         i += 1
@@ -241,6 +244,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
+  private def freeMemory(): Unit = {
+    val iter = allocatedPages.iterator()
+    while (iter.hasNext) {
+      memoryManager.freePage(iter.next())
+      iter.remove()
+    }
+  }
+
   /** Close this writer, passing along whether the map completed */
   override def stop(success: Boolean): Option[MapStatus] = {
     println("Stopping unsafeshufflewriter")
@@ -249,6 +260,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
         None
       } else {
         stopping = true
+        freeMemory()
         if (success) {
           Option(mapStatus)
         } else {
@@ -258,24 +270,14 @@ private[spark] class UnsafeShuffleWriter[K, V](
         }
       }
     } finally {
-      // Clean up our sorter, which may have its own intermediate files
-      if (!allocatedPages.isEmpty) {
-        val iter = allocatedPages.iterator()
-        while (iter.hasNext) {
-          memoryManager.freePage(iter.next())
-          iter.remove()
-        }
-        val startTime = System.nanoTime()
-        //sorter.stop()
-        context.taskMetrics().shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - startTime))
-      }
+      freeMemory()
+      val startTime = System.nanoTime()
+      context.taskMetrics().shuffleWriteMetrics.foreach(
+        _.incShuffleWriteTime(System.nanoTime - startTime))
     }
   }
 }
-
-
 private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager {
 
   private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
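
A note on the on-page record format that the insert and merge hunks above agree
on: each record is stored as an 8-byte partition id, then an 8-byte record
length, then the serialized bytes. That is why the insert path reserves
serializedRecordSize + 8 + 8, and why the merge path reads the length at
baseOffset + 8 and the payload starting at baseOffset + 16. Below is a minimal
sketch of the same layout against a plain java.nio.ByteBuffer; writeRecord and
its parameters are illustrative stand-ins, not code from this patch:

    import java.nio.ByteBuffer

    // Layout per record: [partitionId: 8 bytes][length: 8 bytes][payload]
    def writeRecord(page: ByteBuffer, partitionId: Long, record: Array[Byte]): Unit = {
      // The caller must reserve space first, cf. ensureSpaceInDataPage above.
      require(page.remaining() >= 8 + 8 + record.length)
      page.putLong(partitionId)          // read back to route bytes to the right partition
      page.putLong(record.length.toLong) // read back as recordLength (offset + 8)
      page.put(record)                   // the serialized payload (offset + 16 onward)
    }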
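The new freeMemory() helper runs in both the stopping branch and the finally
block of stop(). The double call is safe because the helper drains
allocatedPages with iterator.remove(), so the second invocation sees an empty
list and does nothing. A standalone sketch of that pattern follows; the
LinkedList stand-in and the freePage stub are hypothetical, for illustration
only:

    import java.util.LinkedList

    val allocatedPages = new LinkedList[AnyRef]()
    def freePage(page: AnyRef): Unit = ()  // stub for memoryManager.freePage

    def freeMemory(): Unit = {
      val iter = allocatedPages.iterator()
      while (iter.hasNext) {
        freePage(iter.next())  // release the page
        iter.remove()          // drop it so it can never be freed twice
      }
    }

    freeMemory()  // frees every page
    freeMemory()  // no-op: the list is already empty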