Refactor to use DiskBlockObjectWriter.
JoshRosen committed May 1, 2015
1 parent 253f13e commit 9c6cf58
Showing 1 changed file with 30 additions and 20 deletions.
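At a high level, this commit replaces hand-rolled FileOutputStream handling (appending to the data file and deriving each partition's length from outputFile.length() arithmetic) with Spark's BlockObjectWriter, which buffers writes, tracks its own FileSegment, and feeds the shuffle write metrics. Below is a minimal sketch of the resulting per-partition lifecycle, with the surrounding values passed in as parameters purely for illustration; it is not the exact shape of the code in this file.

import java.io.File

import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.{BlockId, BlockManager, BlockObjectWriter}

// Illustrative sketch: write one partition's records through a disk writer and
// return that partition's on-disk segment length, with no file-size arithmetic.
object DiskWriterSketch {
  def writePartition(
      blockManager: BlockManager,
      blockId: BlockId,
      outputFile: File,
      serializer: SerializerInstance,
      fileBufferSize: Int,
      writeMetrics: ShuffleWriteMetrics,
      records: Iterator[Array[Byte]]): Long = {
    val writer: BlockObjectWriter =
      blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
    records.foreach { bytes =>
      writer.write(bytes)     // buffered write of one record
      writer.recordWritten()  // keep shuffle write metrics accurate
    }
    writer.commitAndClose()   // flush and finalize this partition's segment
    writer.fileSegment().length
  }
}

Because each partition gets its own writer against the same output file, the segment length comes straight from the writer instead of before/after file-size bookkeeping, as the diff below shows.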
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.unsafe
 
-import java.io.{FileOutputStream, OutputStream}
 import java.nio.ByteBuffer
 import java.util

@@ -29,7 +28,7 @@
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.{BlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.sort.UnsafeSorter
@@ -104,15 +103,21 @@ private[spark] class UnsafeShuffleWriter[K, V](
 
   private[this] val blockManager = SparkEnv.get.blockManager
 
-  private def sortRecords(records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+  private[this] val fileBufferSize =
+    SparkEnv.get.conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
+
+  private[this] val serializer = Serializer.getSerializer(dep.serializer).newInstance()
+
+  private def sortRecords(
+      records: Iterator[_ <: Product2[K, V]]): java.util.Iterator[KeyPointerAndPrefix] = {
     val sorter = new UnsafeSorter(
       context.taskMemoryManager(),
       DummyRecordComparator,
       PartitionerPrefixComputer,
       PartitionerPrefixComparator,
       4096 // initial size
     )
-    val serializer = Serializer.getSerializer(dep.serializer).newInstance()
     val PAGE_SIZE = 1024 * 1024 * 1
 
     var currentPage: MemoryBlock = null
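A note on the fileBufferSize computation in the hunk above: getSizeAsKb accepts a value with or without a size suffix and treats a bare number as KiB for backwards compatibility, so the multiplication by 1024 converts KiB to bytes. A quick illustration of the arithmetic (values assumed for the example):

// REPL-style illustration: "32k" and a legacy unit-less "32" both parse to
// 32 KiB, so each disk writer gets a 32 * 1024 = 32768-byte buffer.
val bufferKb: Long = 32L
val bufferBytes: Int = bufferKb.toInt * 1024  // 32768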
@@ -178,32 +183,31 @@
     sorter.getSortedIterator
   }
 
-  private def writeSortedRecordsToFile(sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
+  private def writeSortedRecordsToFile(
+      sortedRecords: java.util.Iterator[KeyPointerAndPrefix]): Array[Long] = {
     val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID)
     val partitionLengths = new Array[Long](partitioner.numPartitions)
 
     var currentPartition = -1
-    var prevPartitionLength: Long = 0
-    var out: OutputStream = null
+    var writer: BlockObjectWriter = null
 
     // TODO: don't close and re-open file handles so often; this could be inefficient
 
     def closePartition(): Unit = {
-      out.flush()
-      out.close()
-      partitionLengths(currentPartition) = outputFile.length() - prevPartitionLength
+      writer.commitAndClose()
+      partitionLengths(currentPartition) = writer.fileSegment().length
     }
 
     def switchToPartition(newPartition: Int): Unit = {
-      assert (newPartition > currentPartition, s"new partition $newPartition should be >= $currentPartition")
+      assert (newPartition > currentPartition,
+        s"new partition $newPartition should be >= $currentPartition")
       if (currentPartition != -1) {
         closePartition()
-        prevPartitionLength = partitionLengths(currentPartition)
       }
-      println(s"Before switching to partition $newPartition, partition lengths are " + partitionLengths.toSeq)
       currentPartition = newPartition
-      out = blockManager.wrapForCompression(blockId, new FileOutputStream(outputFile, true))
+      writer =
+        blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics)
     }
 
     while (sortedRecords.hasNext) {
@@ -214,18 +218,24 @@
       }
       val baseObject = memoryManager.getPage(keyPointerAndPrefix.recordPointer)
       val baseOffset = memoryManager.getOffsetInPage(keyPointerAndPrefix.recordPointer)
-      val recordLength = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8)
-      println("Base offset is " + baseOffset)
-      println("Record length is " + recordLength)
+      val recordLength: Int = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + 8).toInt
       // TODO: need to have a way to figure out whether a serializer supports relocation of
       // serialized objects or not. Sandy also ran into this in his patch (see
       // https://github.com/apache/spark/pull/4450). If we're using Java serialization, we might
       // as well just bypass this optimized code path in favor of the old one.
-      var i: Int = 0
-      while (i < recordLength) {
-        out.write(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + 16 + i))
-        i += 1
-      }
+      // TODO: re-use a buffer or avoid double-buffering entirely
+      val arr: Array[Byte] = new Array[Byte](recordLength)
+      PlatformDependent.copyMemory(
+        baseObject,
+        baseOffset + 16,
+        arr,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        recordLength)
+      writer.write(arr)
+      // TODO: add a test that detects whether we leave this call out:
+      writer.recordWritten()
     }
     closePartition()

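For readers unfamiliar with the off-heap layout this hunk relies on: the code reads an 8-byte record length at baseOffset + 8 and the serialized payload starting at baseOffset + 16, copying it into an on-heap Array[Byte] so it can be handed to the writer. A standalone sketch of that copy using sun.misc.Unsafe directly follows; the record layout and all names here are assumptions for illustration, not Spark's API.

import java.lang.reflect.Field

object OffHeapRecordCopy {
  // Obtain the Unsafe instance reflectively (the usual pre-Java-9 idiom).
  private val unsafe: sun.misc.Unsafe = {
    val f: Field = classOf[sun.misc.Unsafe].getDeclaredField("theUnsafe")
    f.setAccessible(true)
    f.get(null).asInstanceOf[sun.misc.Unsafe]
  }

  private val byteArrayOffset: Long = unsafe.arrayBaseOffset(classOf[Array[Byte]])

  // Assumed layout: [8 bytes: prefix][8 bytes: length][length bytes: payload].
  // baseObject is null for off-heap pages; for on-heap pages it is the page's array.
  def readPayload(baseObject: AnyRef, baseOffset: Long): Array[Byte] = {
    val length = unsafe.getLong(baseObject, baseOffset + 8).toInt
    val arr = new Array[Byte](length)
    // Bulk copy from the page into the on-heap array, as PlatformDependent.copyMemory does above.
    unsafe.copyMemory(baseObject, baseOffset + 16, arr, byteArrayOffset, length)
    arr
  }
}

The resulting double buffering (page to array to writer) is exactly what the "re-use a buffer or avoid double-buffering entirely" TODO in the diff flags for future work.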
