Begin refactoring to enable proper tests for spilling.
JoshRosen committed May 9, 2015
1 parent 722849b commit 7cd013b
Showing 3 changed files with 102 additions and 37 deletions.
UnsafeShuffleExternalSorter.java
@@ -21,9 +21,9 @@
 import java.io.IOException;
 import java.util.LinkedList;
 
-import org.apache.spark.storage.*;
 import scala.Tuple2;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -32,6 +32,7 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -215,7 +216,8 @@ private SpillInfo writeSpillFile() throws IOException {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  private void spill() throws IOException {
+  @VisibleForTesting
+  void spill() throws IOException {
     final long threadId = Thread.currentThread().getId();
     logger.info("Thread " + threadId + " spilling sort data of " +
       org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" +
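Note: the hunk above widens spill() from private to package-private and marks it @VisibleForTesting, so a test in the same package can force a spill at a deterministic point instead of waiting for real memory pressure. A minimal sketch of the idea; newSorterForTest() and insertTestRecords() are hypothetical helpers that would wire up the sorter's collaborators (task memory manager, shuffle memory manager, block manager, task context, conf) and buffer a few records:

    import java.io.IOException;
    import org.junit.Assert;
    import org.junit.Test;

    public class ForcedSpillSketch {
      @Test
      public void spillCanBeForcedFromTests() throws IOException {
        final UnsafeShuffleExternalSorter sorter = newSorterForTest(); // hypothetical helper
        insertTestRecords(sorter);                                     // hypothetical helper
        // Force a spill on demand; legal here now that spill() is package-private:
        sorter.spill();
        // closeAndGetSpills() hands back the metadata of every spill file written:
        final SpillInfo[] spills = sorter.closeAndGetSpills();
        Assert.assertTrue(spills.length >= 1);
      }
    }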
UnsafeShuffleWriter.java
@@ -32,6 +32,7 @@
 import scala.reflect.ClassTag$;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Files;
 import org.slf4j.Logger;
@@ -73,6 +74,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
 
   private MapStatus mapStatus = null;
+  private UnsafeShuffleExternalSorter sorter = null;
+  private byte[] serArray = null;
+  private ByteBuffer serByteBuffer;
+  // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
+  private SerializationStream serOutputStream;
 
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -113,56 +119,72 @@ public void write(Iterator<Product2<K, V>> records) {
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) {
     try {
-      final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records));
-      shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
     } catch (Exception e) {
       PlatformDependent.throwException(e);
     }
   }
 
-  private void freeMemory() {
-    // TODO
-  }
-
-  private void deleteSpills() {
-    // TODO
-  }
-
-  private SpillInfo[] insertRecordsIntoSorter(
-      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
-    final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
+  private void open() throws IOException {
+    assert (sorter == null);
+    sorter = new UnsafeShuffleExternalSorter(
       memoryManager,
       shuffleMemoryManager,
       blockManager,
       taskContext,
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
 
-    final byte[] serArray = new byte[SER_BUFFER_SIZE];
-    final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray);
+    serArray = new byte[SER_BUFFER_SIZE];
+    serByteBuffer = ByteBuffer.wrap(serArray);
     // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
-    final SerializationStream serOutputStream =
-      serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+    serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+  }
 
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      final int partitionId = partitioner.getPartition(key);
-      serByteBuffer.position(0);
-      serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
-      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
-      serOutputStream.flush();
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
+    if (sorter == null) {
+      open();
+    }
+    serArray = null;
+    serByteBuffer = null;
+    serOutputStream = null;
+    final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills());
+    sorter = null;
+    shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
 
-      final int serializedRecordSize = serByteBuffer.position();
-      assert (serializedRecordSize > 0);
+  private void freeMemory() {
+    // TODO
+  }
 
-      sorter.insertRecord(
-        serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    if (sorter == null) {
+      open();
+    }
+    final K key = record._1();
+    final int partitionId = partitioner.getPartition(key);
+    serByteBuffer.position(0);
+    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();
-
-    return sorter.closeAndGetSpills();
+    final int serializedRecordSize = serByteBuffer.position();
+    assert (serializedRecordSize > 0);
 
+    sorter.insertRecord(
+      serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
   }
 
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);
+    sorter.spill();
+  }
+
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
@@ -222,6 +244,9 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
       for (int i = 0; i < spills.length; i++) {
         if (spillInputStreams[i] != null) {
           spillInputStreams[i].close();
+          if (!spills[i].file.delete()) {
+            logger.error("Error while deleting spill file {}", spills[i]);
+          }
         }
       }
       if (mergedFileOutputStream != null) {
@@ -282,6 +307,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
         assert(spillInputChannelPositions[i] == spills[i].file.length());
         if (spillInputChannels[i] != null) {
           spillInputChannels[i].close();
+          if (!spills[i].file.delete()) {
+            logger.error("Error while deleting spill file {}", spills[i]);
+          }
         }
       }
       if (mergedFileOutputChannel != null) {
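Note: taken together, the changes above decompose the old monolithic write() into steps a test can drive one at a time: insertRecordIntoSorter() serializes and buffers a single record, forceSorterToSpill() triggers a spill on demand, and closeAndWriteOutput() performs the final spill, merges the spill files, and writes the index file and map status. A sketch of the resulting lifecycle, mirroring the new test below (createWriter is the suite's own factory helper):

    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); // serialize + buffer
    writer.forceSorterToSpill();                                     // spill file #1, on demand
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2)); // goes into the final spill
    writer.closeAndWriteOutput();  // final spill, merge, index file
    writer.stop(true);             // returns Some(MapStatus) on success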
UnsafeShuffleWriterSuite.java
@@ -193,13 +193,13 @@ public void writeEmptyIterator() throws Exception {
   @Test
   public void writeWithoutSpilling() throws Exception {
     // In this example, each partition should have exactly one record:
-    final ArrayList<Product2<Object, Object>> datatToWrite =
+    final ArrayList<Product2<Object, Object>> dataToWrite =
       new ArrayList<Product2<Object, Object>>();
     for (int i = 0; i < NUM_PARTITITONS; i++) {
-      datatToWrite.add(new Tuple2<Object, Object>(i, i));
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(datatToWrite.iterator());
+    writer.write(dataToWrite.iterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
     Assert.assertTrue(mergedOutputFile.exists());
@@ -215,7 +215,42 @@ public void writeWithoutSpilling() throws Exception {
     assertSpillFilesWereCleanedUp();
   }
 
+  private void testMergingSpills(boolean transferToEnabled) throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    Assert.assertTrue(mapStatus.isDefined());
+    Assert.assertTrue(mergedOutputFile.exists());
+    Assert.assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void mergeSpillsWithTransferTo() throws Exception {
+    testMergingSpills(true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStream() throws Exception {
+    testMergingSpills(false);
+  }
+
   // TODO: actually try to read the shuffle output?
   // TODO: add a test that manually triggers spills in order to exercise the merging.
+  // }
 
 }
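Note: the suite calls assertSpillFilesWereCleanedUp(), but the helper itself is outside this diff. A plausible implementation, assuming the harness records every spill file it observes in the spillFilesCreated collection used above:

    private void assertSpillFilesWereCleanedUp() {
      for (File spillFile : spillFilesCreated) {
        // Every spill file the writer created should be gone after stop():
        Assert.assertFalse(
          "Spill file " + spillFile.getPath() + " was not cleaned up",
          spillFile.exists());
      }
    }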

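Note: one way to start on the "actually try to read the shuffle output" TODO without involving the serializer is to walk the merged file against the per-partition lengths and confirm the index accounts for every byte. A sketch reusing the suite's mergedOutputFile and partitionSizesInMergedFile fields:

    private void assertMergedFileMatchesIndex() throws IOException {
      final FileInputStream in = new FileInputStream(mergedOutputFile);
      try {
        final byte[] buf = new byte[8192];
        for (long size : partitionSizesInMergedFile) {
          long remaining = size;
          while (remaining > 0) {
            final int read = in.read(buf, 0, (int) Math.min(buf.length, remaining));
            // read() returns -1 at end-of-file, so this catches a truncated file:
            Assert.assertTrue("Merged file is shorter than the index claims", read > 0);
            remaining -= read;
          }
        }
        Assert.assertEquals("Trailing bytes after the last partition", -1, in.read());
      } finally {
        in.close();
      }
    }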