Commit

Begin code cleanup.
JoshRosen committed May 1, 2015
1 parent 4d2f5e1 commit 8e3ec20
Showing 1 changed file with 84 additions and 65 deletions.
@@ -17,23 +17,23 @@

package org.apache.spark.shuffle.unsafe

import java.io.{ByteArrayOutputStream, FileOutputStream}
import java.io.{FileOutputStream, OutputStream}
import java.nio.ByteBuffer
import java.util

import com.esotericsoftware.kryo.io.ByteBufferOutputStream

import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.memory.{TaskMemoryManager, MemoryBlock}
import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
import org.apache.spark.unsafe.sort.UnsafeSorter
import org.apache.spark.unsafe.sort.UnsafeSorter.{KeyPointerAndPrefix, PrefixComparator, PrefixComputer, RecordComparator}
import org.apache.spark.{SparkEnv, SparkConf, ShuffleDependency, TaskContext}
import org.apache.spark.shuffle._

private[spark] class UnsafeShuffleHandle[K, V](
shuffleId: Int,
@@ -87,7 +87,7 @@ private[spark] class UnsafeShuffleWriter[K, V](

private[this] val dep = handle.dependency

private[this] var sorter: UnsafeSorter = null
private[this] val partitioner = dep.partitioner

// Are we in the process of stopping? Because map tasks can call stop() with success = true
// and then call stop() with success = false if they get an exception, we want to make sure
@@ -104,52 +104,55 @@ private[spark] class UnsafeShuffleWriter[K, V](

private[this] val blockManager = SparkEnv.get.blockManager

/** Write a sequence of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
println("Opened writer!")
val serializer = Serializer.getSerializer(dep.serializer).newInstance()
val partitioner = dep.partitioner
sorter = new UnsafeSorter(
private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
val sorter = new UnsafeSorter(
context.taskMemoryManager(),
DummyRecordComparator,
PartitionerPrefixComputer,
PartitionerPrefixComparator,
4096 // initial size
)

// Pack records into data pages:
val serializer = Serializer.getSerializer(dep.serializer).newInstance()
val PAGE_SIZE = 1024 * 1024 * 1

var currentPage: MemoryBlock = memoryManager.allocatePage(PAGE_SIZE)
allocatedPages.add(currentPage)
var currentPagePosition: Long = currentPage.getBaseOffset

// TODO make this configurable
def ensureSpaceInDataPage(spaceRequired: Long): Unit = {
if (spaceRequired > PAGE_SIZE) {
throw new Exception(s"Size requirement $spaceRequired is greater than page size $PAGE_SIZE")
} else if (spaceRequired > (PAGE_SIZE - currentPagePosition)) {
currentPage = memoryManager.allocatePage(PAGE_SIZE)
allocatedPages.add(currentPage)
currentPagePosition = currentPage.getBaseOffset
}
}

// TODO: the size of this buffer should be configurable
val serArray = new Array[Byte](1024 * 1024)
val byteBuffer = ByteBuffer.wrap(serArray)
val bbos = new ByteBufferOutputStream()
bbos.setByteBuffer(byteBuffer)
val serBufferSerStream = serializer.serializeStream(bbos)

while (records.hasNext) {
val nextRecord: Product2[K, V] = records.next()
println("Writing record " + nextRecord)
val partitionId: Int = partitioner.getPartition(nextRecord._1)
serBufferSerStream.writeObject(nextRecord)

val sizeRequirement: Int = byteBuffer.position() + 8 + 8
println("Size requirement in intenral buffer is " + sizeRequirement)
if (sizeRequirement > (PAGE_SIZE - currentPagePosition)) {
println("Allocating a new data page after writing " + currentPagePosition)
currentPage = memoryManager.allocatePage(PAGE_SIZE)
allocatedPages.add(currentPage)
currentPagePosition = currentPage.getBaseOffset
}
println("Before writing record, current page position is " + currentPagePosition)
// TODO: check that it's still not too large
def writeRecord(record: Product2[Any, Any]): Unit = {
val (key, value) = record
val partitionId = partitioner.getPartition(key)
serBufferSerStream.writeKey(key)
serBufferSerStream.writeValue(value)
serBufferSerStream.flush()

val serializedRecordSize = byteBuffer.position()
// TODO: we should run the partition extraction function _now_, at insert time, rather than
// requiring it to be stored alongside the data, since this may lead to double storage
val sizeRequirementInSortDataPage = serializedRecordSize + 8 + 8
ensureSpaceInDataPage(sizeRequirementInSortDataPage)

val newRecordAddress =
memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition)
PlatformDependent.UNSAFE.putLong(currentPage.getBaseObject, currentPagePosition, partitionId)
currentPagePosition += 8
println("The stored record length is " + byteBuffer.position())
PlatformDependent.UNSAFE.putLong(
currentPage.getBaseObject, currentPagePosition, byteBuffer.position())
currentPagePosition += 8
@@ -162,45 +165,53 @@ private[spark] class UnsafeShuffleWriter[K, V](
currentPagePosition += byteBuffer.position()
println("After writing record, current page position is " + currentPagePosition)
sorter.insertRecord(newRecordAddress)

// Reset for writing the next record
byteBuffer.position(0)
}
// TODO: free the buffers, etc, at this point since they're not needed
val sortedIterator: util.Iterator[KeyPointerAndPrefix] = sorter.getSortedIterator
// Now that the partition is sorted, write out the data to a file, keeping track of offsets
// for use in the sort-based shuffle index.

while (records.hasNext) {
writeRecord(records.next())
}

sorter.getSortedIterator
}

private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
val partitionLengths = new Array[Long](partitioner.numPartitions)
// TODO: compression tests?
// TODO why is append true here?
// TODO: metrics tracking and all of the other stuff that diskblockobjectwriter would give us
// TODO: note that we saw FAILED_TO_UNCOMPRESS(5) at some points during debugging when we were
// not properly wrapping the writer for compression even though readers expected compressed
// data; the fact that someone still reported this issue in newer Spark versions suggests that
// we should audit the code to make sure wrapping is done at the right set of places and to
// check that we haven't missed any rare corner-cases / rarely-used paths.
val out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
val serOut = serializer.serializeStream(out)
serOut.flush()

var currentPartition = -1
var currentPartitionLength: Long = 0
while (sortedIterator.hasNext) {
val keyPointerAndPrefix: KeyPointerAndPrefix = sortedIterator.next()
val partition = keyPointerAndPrefix.keyPrefix.toInt
println("Partition is " + partition)
if (currentPartition == -1) {
currentPartition = partition
var prevPartitionLength: Long = 0
var out: OutputStream = null

// TODO: don't close and re-open file handles so often; this could be inefficient

def closePartition(): Unit = {
out.flush()
out.close()
partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
}

def switchToPartition(newPartition: Int): Unit = {
if (currentPartition != -1) {
closePartition()
prevPartitionLength = partitionLengths(currentPartition)
}
currentPartition = newPartition
out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
}

while (sortedRecords.hasNext) {
val keyPointerAndPrefix: KeyPointerAndPrefix = sortedRecords.next()
val partition = keyPointerAndPrefix.keyPrefix.toInt
if (partition != currentPartition) {
println("switching partition")
partitionLengths(currentPartition) = currentPartitionLength
currentPartitionLength = 0
currentPartition = partition
switchToPartition(partition)
}
val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
partitionLengths(currentPartition) += recordLength
println("Base offset is " + baseOffset)
println("Record length is " + recordLength)
var i: Int = 0
@@ -213,10 +224,19 @@ private[spark] class UnsafeShuffleWriter[K, V](
i += 1
}
}
out.flush()
//serOut.close()
//out.flush()
out.close()
closePartition()

partitionLengths
}

/** Write a sequence of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
println("Opened writer!")

val sortedIterator = sortRecords(records)
val partitionLengths = writeSortedRecordsToFile(sortedIterator)

println("Partition lengths are " + partitionLengths.toSeq)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
@@ -239,7 +259,7 @@ private[spark] class UnsafeShuffleWriter[K, V](
}
} finally {
// Clean up our sorter, which may have its own intermediate files
if (sorter != null) {
if (!allocatedPages.isEmpty) {
val iter = allocatedPages.iterator()
while (iter.hasNext) {
memoryManager.freePage(iter.next())
@@ -249,7 +269,6 @@ private[spark] class UnsafeShuffleWriter[K, V](
//sorter.stop()
context.taskMetrics().shuffleWriteMetrics.foreach(
_.incShuffleWriteTime(System.nanoTime - startTime))
sorter = null
}
}
}
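
For readers following the new helpers, here is a small, self-contained sketch (not part of this commit) of the record framing the writer uses inside each data page — an 8-byte partition id, an 8-byte record length, then the serialized payload — together with the per-partition length bookkeeping that the real code eventually hands to shuffleBlockManager.writeIndexFile. It stands in a heap java.nio.ByteBuffer for the off-heap page (MemoryBlock + PlatformDependent.UNSAFE) and a UTF-8 string for the Kryo-serialized bytes, so everything except the layout itself is illustrative.

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Illustrative only: a heap ByteBuffer stands in for the off-heap data page
// and a UTF-8 string for the serialized record bytes.
object RecordLayoutSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 3
    val page = ByteBuffer.allocate(1024)
    val partitionLengths = new Array[Long](numPartitions)

    // Records arrive as (partitionId, payload); writing them in partition
    // order mimics consuming the sorter's sorted iterator.
    val records = Seq((0, "apple"), (2, "banana"), (0, "cherry")).sortBy(_._1)

    // Frame each record as [partitionId: Long][length: Long][payload bytes].
    records.foreach { case (pid, value) =>
      val payload = value.getBytes(StandardCharsets.UTF_8)
      page.putLong(pid.toLong)
      page.putLong(payload.length.toLong)
      page.put(payload)
      // Count only the serialized bytes, matching the recordLength the writer
      // stores (the 16-byte header stays in the page).
      partitionLengths(pid) += payload.length
    }

    // Walk the page again using the same framing, as the write path does when
    // copying sorted records out to the shuffle data file.
    page.flip()
    while (page.hasRemaining) {
      val pid = page.getLong().toInt
      val len = page.getLong().toInt
      val payload = new Array[Byte](len)
      page.get(payload)
      println(s"partition=$pid value=${new String(payload, StandardCharsets.UTF_8)}")
    }
    println("partition lengths: " + partitionLengths.mkString(", "))
  }
}

Writing records grouped by partition is what lets a single pass produce contiguous per-partition regions in the data file; the byte counts of those regions are, in spirit, the partitionLengths array that the writer passes to the index file, though the real code measures them from the output file itself.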
