diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
index 1d31a46993a22..3f746b886bc9b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
@@ -17,15 +17,18 @@
 
 package org.apache.spark.shuffle.unsafe;
 
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.SerializerInstance;
-import scala.reflect.ClassTag;
-
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.nio.ByteBuffer;
 
+import scala.reflect.ClassTag;
+
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.unsafe.PlatformDependent;
+
 /**
  * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
  * Our shuffle write path doesn't actually use this serializer (since we end up calling the
@@ -39,10 +42,17 @@ final class DummySerializerInstance extends SerializerInstance {
   private DummySerializerInstance() { }
 
   @Override
-  public SerializationStream serializeStream(OutputStream s) {
+  public SerializationStream serializeStream(final OutputStream s) {
     return new SerializationStream() {
       @Override
-      public void flush() { }
+      public void flush() {
+        // Need to implement this because DiskObjectWriter uses it to flush the compression stream
+        try {
+          s.flush();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
 
       @Override
       public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
@@ -50,7 +60,14 @@ public SerializationStream writeObject(T t, ClassTag ev1) {
       }
 
       @Override
-      public void close() { }
+      public void close() {
+        // Need to implement this because DiskObjectWriter uses it to close the compression stream
+        try {
+          s.close();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
     };
   }
 
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 511fdfa43d543..01bf7a5095970 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -311,13 +311,12 @@ private void testMergingSpills(
     Assert.assertTrue(mergedOutputFile.exists());
     Assert.assertEquals(2, spillFilesCreated.size());
 
-    // This assertion only holds for the fast merging path:
-    // long sumOfPartitionSizes = 0;
-    // for (long size: partitionSizesInMergedFile) {
-    //   sumOfPartitionSizes += size;
-    // }
-    // Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
-    Assert.assertTrue(mergedOutputFile.length() > 0);
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
     Assert.assertEquals(
       HashMultiset.create(dataToWrite),
       HashMultiset.create(readRecordsFromFile()));
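
For context on why forwarding flush() and close() matters: as the inline comments in the patch say, DiskObjectWriter calls these methods to flush and close the compression stream that wraps the underlying file stream, so the previous no-op implementations could leave buffered compressed data unwritten. The sketch below is not Spark code; it is a minimal illustration using java.util.zip.GZIPOutputStream as a stand-in for Spark's compression codec, with invented class and variable names, to show how much output a wrapping compression stream can hold back until it is closed.

// Illustration only (not Spark code): a wrapping compression stream buffers
// data internally, so a writer whose SerializationStream.flush()/close() are
// no-ops can leave the compressed output truncated.
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.GZIPOutputStream;

public class CompressionCloseSketch {
  public static void main(String[] args) throws IOException {
    byte[] payload = new byte[256 * 1024];

    // Never closing/flushing the compression stream: the GZIP trailer and any
    // internally buffered blocks are never written to the sink.
    ByteArrayOutputStream leakySink = new ByteArrayOutputStream();
    GZIPOutputStream leaky = new GZIPOutputStream(leakySink);
    leaky.write(payload);

    // Closing the compression stream, which is what the patched flush()/close()
    // let DiskObjectWriter do through the dummy SerializationStream.
    ByteArrayOutputStream properSink = new ByteArrayOutputStream();
    try (GZIPOutputStream proper = new GZIPOutputStream(properSink)) {
      proper.write(payload);
    }

    System.out.println("without close(): " + leakySink.size() + " bytes");
    System.out.println("with close():    " + properSink.size() + " bytes");
  }
}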