Revise compression codec support in merger; test cross product of configurations.
JoshRosen committed May 11, 2015
1 parent b57c17f commit 1ef56c7
Showing 4 changed files with 151 additions and 43 deletions.
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle.unsafe;

import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.TempShuffleBlockId;

import java.io.File;

@@ -27,9 +27,9 @@
final class SpillInfo {
final long[] partitionLengths;
final File file;
final BlockId blockId;
final TempShuffleBlockId blockId;

public SpillInfo(int numPartitions, File file, BlockId blockId) {
public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
this.partitionLengths = new long[numPartitions];
this.file = file;
this.blockId = blockId;
@@ -153,7 +153,7 @@ private SpillInfo writeSpillFile() throws IOException {
final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = spilledFileInfo._2();
final BlockId blockId = spilledFileInfo._1();
final TempShuffleBlockId blockId = spilledFileInfo._1();
final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);

// Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
@@ -320,7 +320,7 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException {
}
}
if (requiredSpace > freeSpaceInCurrentPage) {
logger.debug("Required space {} is less than free space in current page ({}}", requiredSpace,
logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
freeSpaceInCurrentPage);
// TODO: we should track metrics on the amount of space wasted when we roll over to a new page
// without using the free space at the end of the current page. We should also do this for
@@ -20,6 +20,7 @@
import java.io.*;
import java.nio.channels.FileChannel;
import java.util.Iterator;
import javax.annotation.Nullable;

import scala.Option;
import scala.Product2;
@@ -35,6 +36,9 @@
import org.slf4j.LoggerFactory;

import org.apache.spark.*;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -53,8 +57,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);

@VisibleForTesting
static final int MAXIMUM_RECORD_SIZE = 1024 * 1024 * 64; // 64 megabytes
private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();

private final BlockManager blockManager;
@@ -201,6 +203,12 @@ void forceSorterToSpill() throws IOException {

private long[] mergeSpills(SpillInfo[] spills) throws IOException {
final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
final boolean fastMergeIsSupported =
!compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
@@ -215,11 +223,20 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
Files.move(spills[0].file, outputFile);
return spills[0].partitionLengths;
} else {
// Need to merge multiple spills.
if (transferToEnabled) {
return mergeSpillsWithTransferTo(spills, outputFile);
if (fastMergeEnabled && fastMergeIsSupported) {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled) {
logger.debug("Using transferTo-based fast merge");
return mergeSpillsWithTransferTo(spills, outputFile);
} else {
logger.debug("Using fileStream-based fast merge");
return mergeSpillsWithFileStream(spills, outputFile, null);
}
} else {
return mergeSpillsWithFileStream(spills, outputFile);
logger.debug("Using slow merge");
return mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
}
}
} catch (IOException e) {
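
Aside (illustrative, not part of this commit): the fast-merge guard above hinges on a property of LZF, namely that streams compressed independently can be concatenated byte-for-byte and decoded through a single decompressing stream. The sketch below demonstrates that property using the same Spark codec API this change imports; the class name and payload strings are made up for the example.

```java
// Illustrative sketch only (not part of this commit). It demonstrates the property the
// fast merge path depends on: with the LZF codec, streams that were compressed
// independently can be concatenated byte-for-byte and read back through a single
// decompressing stream. The class name and payloads below are invented for the example.
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;

import com.google.common.io.ByteStreams;

import org.apache.spark.SparkConf;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZFCompressionCodec;

public final class ConcatenatedLzfStreamsExample {

  private static byte[] compress(CompressionCodec codec, byte[] data) throws IOException {
    final ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    final OutputStream compressed = codec.compressedOutputStream(bytes);
    compressed.write(data);
    compressed.close();
    return bytes.toByteArray();
  }

  public static void main(String[] args) throws IOException {
    final SparkConf conf =
      new SparkConf().set("spark.io.compression.codec", LZFCompressionCodec.class.getName());
    final CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(conf);

    // Compress two "spill partitions" independently, then concatenate the raw bytes,
    // which is effectively what the transferTo- and fileStream-based fast merges do.
    final byte[] first =
      compress(codec, "partition bytes from spill 0".getBytes(StandardCharsets.UTF_8));
    final byte[] second =
      compress(codec, "partition bytes from spill 1".getBytes(StandardCharsets.UTF_8));
    final ByteArrayOutputStream concatenated = new ByteArrayOutputStream();
    concatenated.write(first);
    concatenated.write(second);

    // A single decompressing stream over the concatenation yields both payloads in order.
    final byte[] decompressed = ByteStreams.toByteArray(
      codec.compressedInputStream(new ByteArrayInputStream(concatenated.toByteArray())));
    System.out.println(new String(decompressed, StandardCharsets.UTF_8));
    // Expected output: "partition bytes from spill 0partition bytes from spill 1"
  }
}
```
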
@@ -230,27 +247,40 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
}
}

private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) throws IOException {
private long[] mergeSpillsWithFileStream(
SpillInfo[] spills,
File outputFile,
@Nullable CompressionCodec compressionCodec) throws IOException {
final int numPartitions = partitioner.numPartitions();
final long[] partitionLengths = new long[numPartitions];
final FileInputStream[] spillInputStreams = new FileInputStream[spills.length];
FileOutputStream mergedFileOutputStream = null;
final InputStream[] spillInputStreams = new FileInputStream[spills.length];
OutputStream mergedFileOutputStream = null;

try {
for (int i = 0; i < spills.length; i++) {
spillInputStreams[i] = new FileInputStream(spills[i].file);
}
mergedFileOutputStream = new FileOutputStream(outputFile);

for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = outputFile.length();
mergedFileOutputStream = new FileOutputStream(outputFile, true);
if (compressionCodec != null) {
mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
}

for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
final FileInputStream spillInputStream = spillInputStreams[i];
ByteStreams.copy
(new LimitedInputStream(spillInputStream, partitionLengthInSpill),
mergedFileOutputStream);
partitionLengths[partition] += partitionLengthInSpill;
if (partitionLengthInSpill > 0) {
InputStream partitionInputStream =
new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
if (compressionCodec != null) {
partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
}
ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
}
}
mergedFileOutputStream.flush();
mergedFileOutputStream.close();
partitionLengths[partition] = (outputFile.length() - initialFileLength);
}
} finally {
for (InputStream stream : spillInputStreams) {
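
Aside (illustrative, not part of this commit): in the rewritten slow merge, each partition's length is recorded as the growth of the output file (`outputFile.length() - initialFileLength`) because re-compression changes the byte counts, and readers locate a partition by prefix-summing those compressed lengths into offsets, much as the suite's readRecordsFromFile() does. Below is a sketch of such a reader under that assumption; the class and method names are hypothetical.

```java
// Sketch only, not part of this commit: how a reader can consume one partition from the
// merged data file given the per-partition lengths returned by the merge. The class and
// method names are hypothetical; the stream wiring mirrors readRecordsFromFile() in the
// test suite.
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

import com.google.common.io.ByteStreams;

import org.apache.spark.io.CompressionCodec;
import org.apache.spark.network.util.LimitedInputStream;

final class MergedFilePartitionReader {

  /** Opens a decompressing stream over the given partition of the merged data file. */
  static InputStream openPartition(
      File mergedOutputFile,
      long[] partitionLengths,
      int partition,
      CompressionCodec codec) throws IOException {
    // Offsets are the prefix sums of the *compressed* partition lengths, which is why the
    // slow merge records outputFile.length() deltas rather than the spills' own lengths.
    long startOffset = 0;
    for (int i = 0; i < partition; i++) {
      startOffset += partitionLengths[i];
    }
    final InputStream in = new FileInputStream(mergedOutputFile);
    ByteStreams.skipFully(in, startOffset);
    final InputStream limited = new LimitedInputStream(in, partitionLengths[partition]);
    // Each partition was written through its own compressedOutputStream, so it can be
    // decompressed independently of its neighbours.
    return (codec != null) ? codec.compressedInputStream(limited) : limited;
  }
}
```
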
@@ -37,11 +37,14 @@
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;

import org.apache.spark.*;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.LZ4CompressionCodec;
import org.apache.spark.io.LZFCompressionCodec;
import org.apache.spark.io.SnappyCompressionCodec;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.network.util.LimitedInputStream;
@@ -65,6 +68,7 @@ public class UnsafeShuffleWriterSuite {
File tempDir;
long[] partitionSizesInMergedFile;
final LinkedList<File> spillFilesCreated = new LinkedList<File>();
SparkConf conf;
final Serializer serializer = new KryoSerializer(new SparkConf());

@Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@@ -74,10 +78,14 @@ public class UnsafeShuffleWriterSuite {
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;

private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
return stream;
if (conf.getBoolean("spark.shuffle.compress", true)) {
return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
} else {
return stream;
}
}
}

@@ -98,6 +106,7 @@ public void setUp() throws IOException {
mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
conf = new SparkConf();

when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());

@@ -123,8 +132,35 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th
);
}
});
when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
.then(returnsSecondArg());
when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
new Answer<InputStream>() {
@Override
public InputStream answer(InvocationOnMock invocation) throws Throwable {
assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
InputStream is = (InputStream) invocation.getArguments()[1];
if (conf.getBoolean("spark.shuffle.compress", true)) {
return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
} else {
return is;
}
}
}
);

when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
new Answer<OutputStream>() {
@Override
public OutputStream answer(InvocationOnMock invocation) throws Throwable {
assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
OutputStream os = (OutputStream) invocation.getArguments()[1];
if (conf.getBoolean("spark.shuffle.compress", true)) {
return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
} else {
return os;
}
}
}
);

when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
doAnswer(new Answer<Void>() {
@@ -136,11 +172,11 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
}).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));

when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
new Answer<Tuple2<TempLocalBlockId, File>>() {
new Answer<Tuple2<TempShuffleBlockId, File>>() {
@Override
public Tuple2<TempLocalBlockId, File> answer(
public Tuple2<TempShuffleBlockId, File> answer(
InvocationOnMock invocationOnMock) throws Throwable {
TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
File file = File.createTempFile("spillFile", ".spill", tempDir);
spillFilesCreated.add(file);
return Tuple2$.MODULE$.apply(blockId, file);
@@ -154,7 +190,6 @@ public Tuple2<TempLocalBlockId, File> answer(
}

private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
SparkConf conf = new SparkConf();
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
return new UnsafeShuffleWriter<Object, Object>(
blockManager,
@@ -164,7 +199,7 @@ private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabl
new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
0, // map id
taskContext,
new SparkConf()
conf
);
}

@@ -183,8 +218,11 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
if (partitionSize > 0) {
InputStream in = new FileInputStream(mergedOutputFile);
ByteStreams.skipFully(in, startOffset);
DeserializationStream recordsStream = serializer.newInstance().deserializeStream(
new LimitedInputStream(in, partitionSize));
in = new LimitedInputStream(in, partitionSize);
if (conf.getBoolean("spark.shuffle.compress", true)) {
in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
}
DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
while (records.hasNext()) {
Tuple2<Object, Object> record = records.next();
@@ -245,7 +283,15 @@ public void writeWithoutSpilling() throws Exception {
assertSpillFilesWereCleanedUp();
}

private void testMergingSpills(boolean transferToEnabled) throws IOException {
private void testMergingSpills(
boolean transferToEnabled,
String compressionCodecName) throws IOException {
if (compressionCodecName != null) {
conf.set("spark.shuffle.compress", "true");
conf.set("spark.io.compression.codec", compressionCodecName);
} else {
conf.set("spark.shuffle.compress", "false");
}
final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
final ArrayList<Product2<Object, Object>> dataToWrite =
new ArrayList<Product2<Object, Object>>();
@@ -265,25 +311,57 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
Assert.assertTrue(mergedOutputFile.exists());
Assert.assertEquals(2, spillFilesCreated.size());

long sumOfPartitionSizes = 0;
for (long size: partitionSizesInMergedFile) {
sumOfPartitionSizes += size;
}
Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
// This assertion only holds for the fast merging path:
// long sumOfPartitionSizes = 0;
// for (long size: partitionSizesInMergedFile) {
// sumOfPartitionSizes += size;
// }
// Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
Assert.assertTrue(mergedOutputFile.length() > 0);
Assert.assertEquals(
HashMultiset.create(dataToWrite),
HashMultiset.create(readRecordsFromFile()));
assertSpillFilesWereCleanedUp();
}

@Test
public void mergeSpillsWithTransferTo() throws Exception {
testMergingSpills(true);
public void mergeSpillsWithTransferToAndLZF() throws Exception {
testMergingSpills(true, LZFCompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithFileStreamAndLZF() throws Exception {
testMergingSpills(false, LZFCompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithTransferToAndLZ4() throws Exception {
testMergingSpills(true, LZ4CompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
testMergingSpills(false, LZ4CompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithTransferToAndSnappy() throws Exception {
testMergingSpills(true, SnappyCompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
testMergingSpills(false, SnappyCompressionCodec.class.getName());
}

@Test
public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
testMergingSpills(true, null);
}

@Test
public void mergeSpillsWithFileStream() throws Exception {
testMergingSpills(false);
public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
testMergingSpills(false, null);
}

@Test
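
Aside (illustrative, not part of this commit): the eight merge tests above spell out the full {transferTo, fileStream} x {LZF, LZ4, Snappy, no compression} cross product by hand. A hypothetical loop-driven equivalent inside UnsafeShuffleWriterSuite is sketched below; it assumes setUp() can simply be re-run between configurations, which is the main reason the commit keeps explicit per-configuration @Test methods instead.

```java
// Hypothetical alternative (sketch only, not in this commit): driving the same
// configuration matrix from a single test method inside UnsafeShuffleWriterSuite.
// Each testMergingSpills() run needs the fresh mocks, temp files, and conf that
// setUp() provides, so we re-run it here between configurations.
@Test
public void mergeSpillsAcrossAllConfigurations() throws Exception {
  final String[] codecNames = {
    LZFCompressionCodec.class.getName(),
    LZ4CompressionCodec.class.getName(),
    SnappyCompressionCodec.class.getName(),
    null  // null makes testMergingSpills() set spark.shuffle.compress=false
  };
  for (final boolean transferToEnabled : new boolean[] { true, false }) {
    for (final String codecName : codecNames) {
      setUp();  // re-create mocks, temp dir, and conf for this configuration
      testMergingSpills(transferToEnabled, codecName);
    }
  }
}
```
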
