diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index ac94161d9f242..b995ef826aaf4 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -17,15 +17,15 @@
 
 package org.apache.spark.shuffle.unsafe;
 
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
+import java.io.*;
 import java.util.*;
 
 import scala.*;
+import scala.collection.Iterator;
 import scala.runtime.AbstractFunction1;
 
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -40,19 +40,21 @@
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.*;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
 import org.apache.spark.unsafe.memory.MemoryAllocator;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 import org.apache.spark.util.Utils;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.scheduler.MapStatus;
 
 public class UnsafeShuffleWriterSuite {
 
@@ -64,6 +66,7 @@ public class UnsafeShuffleWriterSuite {
   File tempDir;
   long[] partitionSizesInMergedFile;
   final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  final Serializer serializer = new KryoSerializer(new SparkConf());
 
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -147,8 +150,7 @@ public Tuple2<TempShuffleBlockId, File> answer(
 
     when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
 
-    when(shuffleDep.serializer()).thenReturn(
-      Option.apply(new KryoSerializer(new SparkConf())));
+    when(shuffleDep.serializer()).thenReturn(Option.apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
 
@@ -174,6 +176,27 @@ private void assertSpillFilesWereCleanedUp() {
     }
   }
 
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(
+          new LimitedInputStream(in, partitionSize));
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          recordsList.add(records.next());
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
   @Test(expected=IllegalStateException.class)
   public void mustCallWriteBeforeSuccessfulStop() {
     createWriter(false).stop(true);
@@ -215,19 +238,26 @@ public void writeWithoutSpilling() throws Exception {
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
   }
 
   private void testMergingSpills(boolean transferToEnabled) throws IOException {
     final UnsafeShuffleWriter writer = createWriter(transferToEnabled);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    final ArrayList<Tuple2<Object, Object>> dataToWrite =
+      new ArrayList<Tuple2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
     writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
     writer.closeAndWriteOutput();
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
@@ -239,7 +269,9 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
       sumOfPartitionSizes += size;
     }
     Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-
+    Assert.assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
     assertSpillFilesWereCleanedUp();
   }
 
@@ -263,7 +295,4 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
     writer.stop(false);
     assertSpillFilesWereCleanedUp();
   }
-
-  // TODO: actually try to read the shuffle output?
-
 }