Actually read data in UnsafeShuffleWriterSuite
JoshRosen committed May 10, 2015
1 parent 1929a74 commit 01afc74
Showing 1 changed file with 50 additions and 21 deletions.
@@ -17,15 +17,15 @@

package org.apache.spark.shuffle.unsafe;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.*;
import java.util.*;

import scala.*;
import scala.collection.Iterator;
import scala.runtime.AbstractFunction1;

import com.google.common.collect.HashMultiset;
import com.google.common.io.ByteStreams;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -40,19 +40,21 @@
import static org.mockito.Mockito.*;

import org.apache.spark.*;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.scheduler.MapStatus;

public class UnsafeShuffleWriterSuite {

@@ -64,6 +66,7 @@ public class UnsafeShuffleWriterSuite {
File tempDir;
long[] partitionSizesInMergedFile;
final LinkedList<File> spillFilesCreated = new LinkedList<File>();
final Serializer serializer = new KryoSerializer(new SparkConf());

@Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -147,8 +150,7 @@ public Tuple2<TempLocalBlockId, File> answer(

when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());

when(shuffleDep.serializer()).thenReturn(
Option.<Serializer>apply(new KryoSerializer(new SparkConf())));
when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
}

@@ -174,6 +176,27 @@ private void assertSpillFilesWereCleanedUp() {
}
}

private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
long startOffset = 0;
for (int i = 0; i < NUM_PARTITITONS; i++) {
final long partitionSize = partitionSizesInMergedFile[i];
if (partitionSize > 0) {
InputStream in = new FileInputStream(mergedOutputFile);
ByteStreams.skipFully(in, startOffset);
DeserializationStream recordsStream = serializer.newInstance().deserializeStream(
new LimitedInputStream(in, partitionSize));
Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
while (records.hasNext()) {
recordsList.add(records.next());
}
recordsStream.close();
startOffset += partitionSize;
}
}
return recordsList;
}

@Test(expected=IllegalStateException.class)
public void mustCallWriteBeforeSuccessfulStop() {
createWriter(false).stop(true);
Expand Down Expand Up @@ -215,19 +238,26 @@ public void writeWithoutSpilling() throws Exception {
sumOfPartitionSizes += size;
}
Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);

Assert.assertEquals(
HashMultiset.create(dataToWrite),
HashMultiset.create(readRecordsFromFile()));
assertSpillFilesWereCleanedUp();
}

private void testMergingSpills(boolean transferToEnabled) throws IOException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.insertRecordIntoSorter(dataToWrite.get(0));
writer.insertRecordIntoSorter(dataToWrite.get(1));
writer.insertRecordIntoSorter(dataToWrite.get(2));
writer.insertRecordIntoSorter(dataToWrite.get(3));
writer.forceSorterToSpill();
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
writer.insertRecordIntoSorter(dataToWrite.get(4));
writer.insertRecordIntoSorter(dataToWrite.get(5));
writer.closeAndWriteOutput();
final Option<MapStatus> mapStatus = writer.stop(true);
Assert.assertTrue(mapStatus.isDefined());
@@ -239,7 +269,9 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
sumOfPartitionSizes += size;
}
Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);

Assert.assertEquals(
HashMultiset.create(dataToWrite),
HashMultiset.create(readRecordsFromFile()));
assertSpillFilesWereCleanedUp();
}

@@ -263,7 +295,4 @@ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
writer.stop(false);
assertSpillFilesWereCleanedUp();
}

// TODO: actually try to read the shuffle output?

}
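For context, here is a minimal, standalone sketch of the read-back pattern the new readRecordsFromFile() helper uses: seek to each partition's start offset in the merged output file, read exactly that partition's bytes, and compare what comes back against what was written as a multiset. This is an illustration only, not code from the commit: Guava's ByteStreams.limit stands in for Spark's internal LimitedInputStream, raw bytes stand in for serialized records, and the class and variable names are hypothetical.

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.google.common.collect.HashMultiset;
import com.google.common.io.ByteStreams;

public class PartitionedReadSketch {
  public static void main(String[] args) throws IOException {
    // The "merged output file": three partitions concatenated back to back.
    // Partition 1 is empty, like a reducer that receives no records.
    byte[] merged = new byte[] {10, 11, 12, 20, 21};
    long[] partitionSizesInMergedFile = {3, 0, 2};

    List<Byte> readBack = new ArrayList<Byte>();
    long startOffset = 0;
    for (long partitionSize : partitionSizesInMergedFile) {
      if (partitionSize > 0) {
        InputStream in = new ByteArrayInputStream(merged);
        ByteStreams.skipFully(in, startOffset);                       // seek to this partition
        InputStream partition = ByteStreams.limit(in, partitionSize); // stop at its end
        int b;
        while ((b = partition.read()) != -1) {
          readBack.add((byte) b);
        }
        in.close();
      }
      startOffset += partitionSize;
    }

    // The suite does not assert on ordering; a multiset comparison checks that the
    // same records (including duplicates) come back from the file.
    List<Byte> written = Arrays.asList((byte) 10, (byte) 11, (byte) 12, (byte) 20, (byte) 21);
    System.out.println(HashMultiset.create(written).equals(HashMultiset.create(readBack))); // true
  }
}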
