diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 0167002ceedb8..2237ec052cac2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -59,6 +59,10 @@ public interface ShuffleMapOutputWriter { * available to downstream reduce tasks. If this method throws any exception, this module's * {@link #abort(Throwable)} method will be invoked before propagating the exception. *

+ * Shuffle extensions that care about the cause of shuffle data corruption should store + * the checksums properly. When corruption happens, Spark will provide the checksum + * of the fetched partition to the shuffle extension to help diagnose the cause of the corruption. + *

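As an illustration of the contract described above, a plugin-side implementation could persist the per-partition checksums at commit time. This is only a sketch: the class name and the checksum storage are hypothetical, and the imports assume the Spark 3.x shuffle plugin API packages.

import java.io.{DataOutputStream, File, FileOutputStream}
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage

// Hypothetical plugin writer; only the checksum handling is spelled out here.
class ChecksumAwareMapOutputWriter(numPartitions: Int, checksumFile: File)
    extends ShuffleMapOutputWriter {
  private val partitionLengths = new Array[Long](numPartitions)

  override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter =
    throw new UnsupportedOperationException("omitted in this sketch")

  override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage = {
    // An empty array means spark.shuffle.checksum.enabled=false; otherwise
    // checksums(i) is the checksum of reduce partition i.
    if (checksums.nonEmpty) {
      val out = new DataOutputStream(new FileOutputStream(checksumFile))
      try checksums.foreach(out.writeLong) finally out.close()
    }
    MapOutputCommitMessage.of(partitionLengths)
  }

  override def abort(error: Throwable): Unit = ()
}

The only new obligation on implementations is to treat an empty array as "checksum disabled" and otherwise index it by reduce partition id.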
* This can also close any resources and clean up temporary state if necessary. *

* The returned commit message is a structure with two components: @@ -68,8 +72,11 @@ public interface ShuffleMapOutputWriter { * for that partition id. *

* 2) An optional metadata blob that can be used by shuffle readers. + * + * @param checksums The checksum values for each partition (where checksum index is equivalent to + * partition id) if shuffle checksum enabled. Otherwise, it's empty. */ - MapOutputCommitMessage commitAllPartitions() throws IOException; + MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException; /** * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 928875156a70f..143cc6c871e5f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -49,7 +49,7 @@ public interface ShufflePartitionWriter { * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that * {@link OutputStream#close()} does not close the resource, since it will be reused across * partition writes. The underlying resources should be cleaned up in - * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and * {@link ShuffleMapOutputWriter#abort(Throwable)}. */ OutputStream openStream() throws IOException; @@ -68,7 +68,7 @@ public interface ShufflePartitionWriter { * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel * will be reused across partition writes. The underlying resources should be cleaned up in - * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and * {@link ShuffleMapOutputWriter#abort(Throwable)}. *

* This method is primarily for advanced optimizations where bytes can be copied from the input @@ -79,7 +79,7 @@ public interface ShufflePartitionWriter { * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, - * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or + * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])}, or * {@link ShuffleMapOutputWriter#abort(Throwable)}. */ default Optional openChannelWrapper() throws IOException { diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java index cad8dcfda52bc..ba3d5a603e052 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java @@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter { /** * Transfer a file that contains the bytes of all the partitions written by this map task. */ - void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException; + void transferMapSpillFile( + File mapOutputFile, + long[] partitionLengths, + long[] checksums) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java new file mode 100644 index 0000000000000..a368836d2bb1d --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum; + +import java.util.zip.Adler32; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkException; +import org.apache.spark.annotation.Private; +import org.apache.spark.internal.config.package$; +import org.apache.spark.storage.ShuffleChecksumBlockId; + +/** + * A set of utility functions for the shuffle checksum. + */ +@Private +public class ShuffleChecksumHelper { + + /** Used when the checksum is disabled for shuffle. 
*/ + private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; + public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; + + public static boolean isShuffleChecksumEnabled(SparkConf conf) { + return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED()); + } + + public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf) + throws SparkException { + if (!isShuffleChecksumEnabled(conf)) { + return EMPTY_CHECKSUM; + } + + String checksumAlgo = shuffleChecksumAlgorithm(conf); + return getChecksumByAlgorithm(numPartitions, checksumAlgo); + } + + private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) + throws SparkException { + Checksum[] checksums; + switch (algorithm) { + case "ADLER32": + checksums = new Adler32[num]; + for (int i = 0; i < num; i ++) { + checksums[i] = new Adler32(); + } + return checksums; + + case "CRC32": + checksums = new CRC32[num]; + for (int i = 0; i < num; i ++) { + checksums[i] = new CRC32(); + } + return checksums; + + default: + throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm); + } + } + + public static long[] getChecksumValues(Checksum[] partitionChecksums) { + int numPartitions = partitionChecksums.length; + long[] checksumValues = new long[numPartitions]; + for (int i = 0; i < numPartitions; i ++) { + checksumValues[i] = partitionChecksums[i].getValue(); + } + return checksumValues; + } + + public static String shuffleChecksumAlgorithm(SparkConf conf) { + return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + } + + public static Checksum getChecksumByFileExtension(String fileName) throws SparkException { + int index = fileName.lastIndexOf("."); + String algorithm = fileName.substring(index + 1); + return getChecksumByAlgorithm(1, algorithm)[0]; + } + + public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) { + // append the shuffle checksum algorithm as the file extension + return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf)); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 3dbee1b13d287..322224053df09 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -23,6 +23,7 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; import java.util.Optional; +import java.util.zip.Checksum; import javax.annotation.Nullable; import scala.None$; @@ -38,6 +39,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.SparkException; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; @@ -49,6 +51,7 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -93,6 +96,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; private long[] partitionLengths; + 
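For reference, the helper above is controlled by the two new configs introduced later in this patch (spark.shuffle.checksum.enabled and spark.shuffle.checksum.algorithm). A minimal sketch of the calling pattern the shuffle writers follow, with illustrative values:

import java.util.zip.Checksum
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper

val conf = new SparkConf()
  .set("spark.shuffle.checksum.enabled", "true")     // default: true
  .set("spark.shuffle.checksum.algorithm", "CRC32")  // default: ADLER32 (ADLER32 or CRC32)
val numPartitions = 4                                // illustrative value

// One Checksum per reduce partition, or an empty array when the feature is disabled.
val partitionChecksums: Array[Checksum] =
  ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)

// Writers update partitionChecksums(i) while writing partition i, then extract the long
// values at commit time and pass them to commitAllPartitions / transferMapSpillFile.
val checksumValues: Array[Long] =
  ShuffleChecksumHelper.getChecksumValues(partitionChecksums)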
/** Checksum calculator for each partition. Empty when shuffle checksum disabled. */ + private final Checksum[] partitionChecksums; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -107,7 +112,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { long mapId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleExecutorComponents shuffleExecutorComponents) { + ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -120,6 +125,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleExecutorComponents = shuffleExecutorComponents; + this.partitionChecksums = + ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); } @Override @@ -129,7 +136,8 @@ public void write(Iterator> records) throws IOException { .createMapOutputWriter(shuffleId, mapId, numPartitions); try { if (!records.hasNext()) { - partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths(); + partitionLengths = mapOutputWriter.commitAllPartitions( + ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); mapStatus = MapStatus$.MODULE$.apply( blockManager.shuffleServerId(), partitionLengths, mapId); return; @@ -143,8 +151,12 @@ public void write(Iterator> records) throws IOException { blockManager.diskBlockManager().createTempShuffleBlock(); final File file = tempShuffleBlockIdPlusFile._2(); final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + DiskBlockObjectWriter writer = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + if (partitionChecksums.length > 0) { + writer.setChecksum(partitionChecksums[i]); + } + partitionWriters[i] = writer; } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be @@ -218,7 +230,9 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro } partitionWriters = null; } - return mapOutputWriter.commitAllPartitions().getPartitionLengths(); + return mapOutputWriter.commitAllPartitions( + ShuffleChecksumHelper.getChecksumValues(partitionChecksums) + ).getPartitionLengths(); } private void writePartitionedDataWithChannel( diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 833744f4777ce..0307027c6f264 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -21,7 +21,9 @@ import java.io.File; import java.io.IOException; import java.util.LinkedList; +import java.util.zip.Checksum; +import org.apache.spark.SparkException; import scala.Tuple2; import com.google.common.annotations.VisibleForTesting; @@ -39,6 +41,7 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import 
org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -107,6 +110,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { @Nullable private MemoryBlock currentPage = null; private long pageCursor = -1; + // Checksum calculator for each partition. Empty when shuffle checksum disabled. + private final Checksum[] partitionChecksums; + ShuffleExternalSorter( TaskMemoryManager memoryManager, BlockManager blockManager, @@ -114,7 +120,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics) throws SparkException { super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), memoryManager.getTungstenMemoryMode()); @@ -133,6 +139,12 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.peakMemoryUsedBytes = getMemoryUsage(); this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); + this.partitionChecksums = + ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); + } + + public long[] getChecksums() { + return ShuffleChecksumHelper.getChecksumValues(partitionChecksums); } /** @@ -204,6 +216,9 @@ private void writeSortedFile(boolean isLastFile) { spillInfo.partitionLengths[currentPartition] = fileSegment.length(); } currentPartition = partition; + if (partitionChecksums.length > 0) { + writer.setChecksum(partitionChecksums[currentPartition]); + } } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e8f94ba8ffeee..2659b172bf68c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -57,6 +57,7 @@ import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; @@ -115,7 +116,7 @@ public UnsafeShuffleWriter( TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, - ShuffleExecutorComponents shuffleExecutorComponents) { + ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException { final int numPartitions = handle.dependency().partitioner().numPartitions(); if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) { throw new IllegalArgumentException( @@ -198,7 +199,7 @@ public void write(scala.collection.Iterator> records) throws IOEx } } - private void open() { + private void open() throws SparkException { assert (sorter == null); sorter = new ShuffleExternalSorter( memoryManager, @@ -219,10 +220,10 @@ void closeAndWriteOutput() throws IOException { serBuffer = null; serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); - sorter = null; try { partitionLengths 
= mergeSpills(spills); } finally { + sorter = null; for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Error while deleting spill file {}", spill.file.getPath()); @@ -267,7 +268,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { if (spills.length == 0) { final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); - return mapWriter.commitAllPartitions().getPartitionLengths(); + return mapWriter.commitAllPartitions( + ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); } else if (spills.length == 1) { Optional maybeSingleFileWriter = shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); @@ -277,7 +279,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { partitionLengths = spills[0].partitionLengths; logger.debug("Merge shuffle spills for mapId {} with length {}", mapId, partitionLengths.length); - maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths); + maybeSingleFileWriter.get() + .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums()); } else { partitionLengths = mergeSpillsUsingStandardWriter(spills); } @@ -330,7 +333,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep // to be counted as shuffle write, but this will lead to double-counting of the final // SpillInfo's bytes. writeMetrics.decBytesWritten(spills[spills.length - 1].file.length()); - partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths(); + partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths(); } catch (Exception e) { try { mapWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index 0b286264be43d..6c5025d1822f8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -98,7 +98,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I } @Override - public MapOutputCommitMessage commitAllPartitions() throws IOException { + public MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException { // Check the position after transferTo loop to see if it is in the right position and raise a // exception if it is incorrect. The position will not be increased to the expected length // after calling transferTo in kernel version 2.6.32. This issue is described at @@ -115,7 +115,8 @@ public MapOutputCommitMessage commitAllPartitions() throws IOException { File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; log.debug("Writing shuffle index file for mapId {} with length {}", mapId, partitionLengths.length); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + blockResolver + .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, resolvedTmp); return MapOutputCommitMessage.of(partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index c8b41992a8919..6a994b49d3a29 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -44,12 +44,14 @@ public LocalDiskSingleSpillMapOutputWriter( @Override public void transferMapSpillFile( File mapSpillFile, - long[] partitionLengths) throws IOException { + long[] partitionLengths, + long[] checksums) throws IOException { // The map spill file already has the proper format, and it contains all of the partition data. // So just transfer it directly to the destination without any merging. File outputFile = blockResolver.getDataFile(shuffleId, mapId); File tempFile = Utils.tempFileWith(outputFile); Files.move(mapSpillFile.toPath(), tempFile.toPath()); - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile); + blockResolver + .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, tempFile); } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 613a66d2d5aca..3ef964fcb8fd9 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1368,6 +1368,25 @@ package object config { s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.") .createWithDefault(4096) + private[spark] val SHUFFLE_CHECKSUM_ENABLED = + ConfigBuilder("spark.shuffle.checksum.enabled") + .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " + + "its best to tell if shuffle data corruption is caused by network or disk or others.") + .version("3.3.0") + .booleanConf + .createWithDefault(true) + + private[spark] val SHUFFLE_CHECKSUM_ALGORITHM = + ConfigBuilder("spark.shuffle.checksum.algorithm") + .doc("The algorithm used to calculate the checksum. Currently, it only supports" + + " built-in algorithms of JDK.") + .version("3.3.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " + + "should be either Adler32 or CRC32.") + .createWithDefault("ADLER32") + private[spark] val SHUFFLE_COMPRESS = ConfigBuilder("spark.shuffle.compress") .doc("Whether to compress shuffle output. Compression will use " + diff --git a/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala new file mode 100644 index 0000000000000..754b4a87720a9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.io + +import java.io.OutputStream +import java.util.zip.Checksum + +/** + * A variant of [[java.util.zip.CheckedOutputStream]] which can + * change the checksum calculator at runtime. + */ +class MutableCheckedOutputStream(out: OutputStream) extends OutputStream { + private var checksum: Checksum = _ + + def setChecksum(c: Checksum): Unit = { + this.checksum = c + } + + override def write(b: Int): Unit = { + assert(checksum != null, "Checksum is not set.") + checksum.update(b) + out.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + assert(checksum != null, "Checksum is not set.") + checksum.update(b, off, len) + out.write(b, off, len) + } + + override def flush(): Unit = out.flush() + + override def close(): Unit = out.close() +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 5d1da19e69f42..9c50569c78924 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -22,6 +22,8 @@ import java.nio.ByteBuffer import java.nio.channels.Channels import java.nio.file.Files +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream @@ -31,6 +33,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta} import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -142,17 +145,18 @@ private[spark] class IndexShuffleBlockResolver( */ def removeDataByMap(shuffleId: Int, mapId: Long): Unit = { var file = getDataFile(shuffleId, mapId) - if (file.exists()) { - if (!file.delete()) { - logWarning(s"Error deleting data ${file.getPath()}") - } + if (file.exists() && !file.delete()) { + logWarning(s"Error deleting data ${file.getPath()}") } file = getIndexFile(shuffleId, mapId) - if (file.exists()) { - if (!file.delete()) { - logWarning(s"Error deleting index ${file.getPath()}") - } + if (file.exists() && !file.delete()) { + logWarning(s"Error deleting index ${file.getPath()}") + } + + file = getChecksumFile(shuffleId, mapId) + if (file.exists() && !file.delete()) { + logWarning(s"Error deleting checksum ${file.getPath()}") } } @@ -303,22 +307,41 @@ private[spark] class IndexShuffleBlockResolver( /** - * Write an index file with the offsets of each block, plus a final offset at the end for the - 
* end of the output file. This will be used by getBlockData to figure out where each block - * begins and ends. + * Commit the data and metadata files as an atomic operation, use the existing ones, or + * replace them with new ones. Note that the metadata parameters (`lengths`, `checksums`) + * will be updated to match the existing ones if use the existing ones. + * + * There're two kinds of metadata files: * - * It will commit the data and index file as an atomic operation, use the existing ones, or - * replace them with new ones. + * - index file + * An index file contains the offsets of each block, plus a final offset at the end + * for the end of the output file. It will be used by [[getBlockData]] to figure out + * where each block begins and ends. * - * Note: the `lengths` will be updated to match the existing index file if use the existing ones. + * - checksum file (optional) + * An checksum file contains the checksum of each block. It will be used to diagnose + * the cause when a block is corrupted. Note that empty `checksums` indicate that + * checksum is disabled. */ - def writeIndexFileAndCommit( + def writeMetadataFileAndCommit( shuffleId: Int, mapId: Long, lengths: Array[Long], + checksums: Array[Long], dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) + + val checksumEnabled = checksums.nonEmpty + val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) { + assert(lengths.length == checksums.length, + "The size of partition lengths and checksums should be equal") + val checksumFile = getChecksumFile(shuffleId, mapId) + (Some(checksumFile), Some(Utils.tempFileWith(checksumFile))) + } else { + (None, None) + } + try { val dataFile = getDataFile(shuffleId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure @@ -329,37 +352,47 @@ private[spark] class IndexShuffleBlockResolver( // Another attempt for the same task has already written our map outputs successfully, // so just use the existing partition lengths and delete our temporary map outputs. System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (checksumEnabled) { + val existingChecksums = getChecksums(checksumFileOpt.get, checksums.length) + if (existingChecksums != null) { + System.arraycopy(existingChecksums, 0, checksums, 0, lengths.length) + } else { + // It's possible that the previous task attempt succeeded writing the + // index file and data file but failed to write the checksum file. In + // this case, the current task attempt could write the missing checksum + // file by itself. + writeMetadataFile(checksums, checksumTmpOpt.get, checksumFileOpt.get, false) + } + } if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } } else { // This is the first successful attempt in writing the map outputs for this task, // so override any existing index and data files with the ones we wrote. - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. 
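The removed offset loop below and the scanLeft call that replaces it a few lines further down produce the same index contents; a tiny illustration:

// Partition lengths become cumulative offsets: a leading 0, one entry per partition,
// and a final entry equal to the total data-file length.
val lengths = Array(10L, 0L, 5L)
val offsets = lengths.scanLeft(0L)(_ + _)   // Array(0, 10, 10, 15)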
- var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length - out.writeLong(offset) - } - } { - out.close() - } - if (indexFile.exists()) { - indexFile.delete() - } + val offsets = lengths.scanLeft(0L)(_ + _) + writeMetadataFile(offsets, indexTmp, indexFile, true) + if (dataFile.exists()) { dataFile.delete() } - if (!indexTmp.renameTo(indexFile)) { - throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) - } if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } + + // write the checksum file + checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) => + try { + writeMetadataFile(checksums, checksumTmp, checksumFile, false) + } catch { + case e: Exception => + // It's not worthwhile to fail here after index file and data file are + // already successfully stored since checksum is only a best-effort for + // the corner error case. + logError("Failed to write checksum file", e) + } + } } } } finally { @@ -367,6 +400,63 @@ private[spark] class IndexShuffleBlockResolver( if (indexTmp.exists() && !indexTmp.delete()) { logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") } + checksumTmpOpt.foreach { checksumTmp => + if (checksumTmp.exists()) { + try { + if (!checksumTmp.delete()) { + logError(s"Failed to delete temporary checksum file " + + s"at ${checksumTmp.getAbsolutePath}") + } + } catch { + case e: Exception => + // Unlike index deletion, we won't propagate the error for the checksum file since + // checksum is only a best-effort. + logError(s"Failed to delete temporary checksum file " + + s"at ${checksumTmp.getAbsolutePath}", e) + } + } + } + } + } + + /** + * Write the metadata file (index or checksum). Metadata values will be firstly write into + * the tmp file and the tmp file will be renamed to the target file at the end to avoid dirty + * writes. + * @param metaValues The metadata values + * @param tmpFile The temp file + * @param targetFile The target file + * @param propagateError Whether to propagate the error for file operation. Unlike index file, + * checksum is only a best-effort so we won't fail the whole task due to + * the error from checksum. 
+ */ + private def writeMetadataFile( + metaValues: Array[Long], + tmpFile: File, + targetFile: File, + propagateError: Boolean): Unit = { + val out = new DataOutputStream( + new BufferedOutputStream( + new FileOutputStream(tmpFile) + ) + ) + Utils.tryWithSafeFinally { + metaValues.foreach(out.writeLong) + } { + out.close() + } + + if (targetFile.exists()) { + targetFile.delete() + } + + if (!tmpFile.renameTo(targetFile)) { + val errorMsg = s"fail to rename file $tmpFile to $targetFile" + if (propagateError) { + throw new IOException(errorMsg) + } else { + logWarning(errorMsg) + } } } @@ -414,6 +504,45 @@ private[spark] class IndexShuffleBlockResolver( new MergedBlockMeta(numChunks, chunkBitMaps) } + private[shuffle] def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = { + if (!checksumFile.exists()) return null + val checksums = new ArrayBuffer[Long] + // Read the checksums of blocks + var in: DataInputStream = null + try { + in = new DataInputStream(new NioBufferedFileInputStream(checksumFile)) + while (checksums.size < blockNum) { + checksums += in.readLong() + } + } catch { + case _: IOException | _: EOFException => + return null + } finally { + in.close() + } + + checksums.toArray + } + + /** + * Get the shuffle checksum file. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + def getChecksumFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName)) + .getOrElse { + blockManager.diskBlockManager.getFile(fileName) + } + } + override def getBlockData( blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala index e0affb858c359..9843ae18bba77 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -18,7 +18,9 @@ package org.apache.spark.shuffle import java.io.{Closeable, IOException, OutputStream} +import java.util.zip.Checksum +import org.apache.spark.io.MutableCheckedOutputStream import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.api.ShufflePartitionWriter import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream} @@ -34,7 +36,8 @@ private[spark] class ShufflePartitionPairsWriter( serializerManager: SerializerManager, serializerInstance: SerializerInstance, blockId: BlockId, - writeMetrics: ShuffleWriteMetricsReporter) + writeMetrics: ShuffleWriteMetricsReporter, + checksum: Checksum) extends PairsWriter with Closeable { private var isClosed = false @@ -44,6 +47,9 @@ private[spark] class ShufflePartitionPairsWriter( private var objOut: SerializationStream = _ private var numRecordsWritten = 0 private var curNumBytesWritten = 0L + // this would be only initialized when checksum != null, + // which indicates shuffle checksum is enabled. 
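The checksumOutputStream field declared just below is a MutableCheckedOutputStream, the same wrapper used by DiskBlockObjectWriter. A self-contained sketch of how it composes (stream names are illustrative):

import java.io.ByteArrayOutputStream
import java.util.zip.Adler32
import org.apache.spark.io.MutableCheckedOutputStream

val underlying = new ByteArrayOutputStream()               // stands in for the partition stream
val checked = new MutableCheckedOutputStream(underlying)   // forwards every byte to `underlying`
val checksum = new Adler32()
checked.setChecksum(checksum)                              // must be set before the first write
checked.write(Array[Byte](1, 2, 3), 0, 3)
checked.flush()
// The wrapped stream saw exactly the same bytes, and the checksum now reflects them.
val value: Long = checksum.getValue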
+ private var checksumOutputStream: MutableCheckedOutputStream = _ override def write(key: Any, value: Any): Unit = { if (isClosed) { @@ -61,7 +67,12 @@ private[spark] class ShufflePartitionPairsWriter( try { partitionStream = partitionWriter.openStream timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream) - wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream) + if (checksum != null) { + checksumOutputStream = new MutableCheckedOutputStream(timeTrackingStream) + checksumOutputStream.setChecksum(checksum) + } + wrappedStream = serializerManager.wrapStream(blockId, + if (checksumOutputStream != null) checksumOutputStream else timeTrackingStream) objOut = serializerInstance.serializeStream(wrappedStream) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index adbe6eca5800b..3cbf30160efb8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -68,7 +68,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths + partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index dc70a9af7e9c3..db5862dec2fbe 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -92,6 +92,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } +@Since("3.3.0") +@DeveloperApi +case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum" +} + @Since("3.2.0") @DeveloperApi case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e55c09274cd9a..f5d8c0219dc83 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -19,8 +19,10 @@ package org.apache.spark.storage import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.{ClosedByInterruptException, FileChannel} +import java.util.zip.Checksum import org.apache.spark.internal.Logging +import org.apache.spark.io.MutableCheckedOutputStream import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils @@ -77,6 +79,11 @@ private[spark] class DiskBlockObjectWriter( private var streamOpen = false private var hasBeenClosed = false + // checksum related 
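The new ShuffleChecksumBlockId above names the per-map checksum file, and ShuffleChecksumHelper appends the configured algorithm as the file extension. For illustration (the ids are made up):

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.storage.ShuffleChecksumBlockId

val blockId = ShuffleChecksumBlockId(shuffleId = 0, mapId = 5L, reduceId = 0)
val name = blockId.name                       // "shuffle_0_5_0.checksum"
val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, new SparkConf(false))
// "shuffle_0_5_0.checksum.ADLER32" under the default algorithm

The checksum-related fields that follow in DiskBlockObjectWriter carry the per-partition Checksum instance whose final value ends up in that file.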
+ private var checksumEnabled = false + private var checksumOutputStream: MutableCheckedOutputStream = _ + private var checksum: Checksum = _ + /** * Cursors used to represent positions in the file. * @@ -101,12 +108,30 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 + /** + * Set the checksum that the checksumOutputStream should use + */ + def setChecksum(checksum: Checksum): Unit = { + if (checksumOutputStream == null) { + this.checksumEnabled = true + this.checksum = checksum + } else { + checksumOutputStream.setChecksum(checksum) + } + } + private def initialize(): Unit = { fos = new FileOutputStream(file, true) channel = fos.getChannel() ts = new TimeTrackingOutputStream(writeMetrics, fos) + if (checksumEnabled) { + assert(this.checksum != null, "Checksum is not set") + checksumOutputStream = new MutableCheckedOutputStream(ts) + checksumOutputStream.setChecksum(checksum) + } class ManualCloseBufferedOutputStream - extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream + extends BufferedOutputStream(if (checksumEnabled) checksumOutputStream else ts, bufferSize) + with ManualCloseOutputStream mcs = new ManualCloseBufferedOutputStream } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 1913637371e31..dba9e749a573c 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -31,6 +31,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils => TryUtils} @@ -141,6 +142,11 @@ private[spark] class ExternalSorter[K, V, C]( private val forceSpillFiles = new ArrayBuffer[SpilledFile] @volatile private var readingIterator: SpillableIterator = null + private val partitionChecksums = + ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf) + + def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums) + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. 
(A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -762,7 +768,8 @@ private[spark] class ExternalSorter[K, V, C]( serializerManager, serInstance, blockId, - context.taskMetrics().shuffleWriteMetrics) + context.taskMetrics().shuffleWriteMetrics, + if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null) while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(partitionPairsWriter) } @@ -786,7 +793,8 @@ private[spark] class ExternalSorter[K, V, C]( serializerManager, serInstance, blockId, - context.taskMetrics().shuffleWriteMetrics) + context.taskMetrics().shuffleWriteMetrics, + if (partitionChecksums.nonEmpty) partitionChecksums(id) else null) if (elements.hasNext) { for (elem <- elements) { partitionPairsWriter.write(elem._1, elem._2) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 5666bb3e5f140..cca3eb5b2cbc3 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -22,11 +22,11 @@ import java.nio.file.Files; import java.util.*; +import org.apache.spark.*; +import org.apache.spark.shuffle.ShuffleChecksumTestHelper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.mockito.stubbing.Answer; -import scala.Option; -import scala.Product2; -import scala.Tuple2; -import scala.Tuple2$; +import scala.*; import scala.collection.Iterator; import com.google.common.collect.HashMultiset; @@ -36,10 +36,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.apache.spark.HashPartitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkConf; -import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; @@ -65,7 +61,7 @@ import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; -public class UnsafeShuffleWriterSuite { +public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper { static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096; static final int NUM_PARTITIONS = 4; @@ -138,7 +134,7 @@ public void setUp() throws Exception { Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; - File tmp = (File) invocationOnMock.getArguments()[3]; + File tmp = (File) invocationOnMock.getArguments()[4]; if (!mergedOutputFile.delete()) { throw new RuntimeException("Failed to delete old merged output file."); } @@ -152,11 +148,13 @@ public void setUp() throws Exception { doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); + .writeMetadataFileAndCommit( + anyInt(), anyLong(), any(long[].class), any(long[].class), any(File.class)); doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); + .writeMetadataFileAndCommit( + anyInt(), anyLong(), any(long[].class), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -171,7 +169,14 @@ public void setUp() throws Exception { 
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); } - private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { + private UnsafeShuffleWriter createWriter(boolean transferToEnabled) + throws SparkException { + return createWriter(transferToEnabled, shuffleBlockResolver); + } + + private UnsafeShuffleWriter createWriter( + boolean transferToEnabled, + IndexShuffleBlockResolver blockResolver) throws SparkException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( blockManager, @@ -181,7 +186,7 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabl taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver)); + new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver)); } private void assertSpillFilesWereCleanedUp() { @@ -219,12 +224,12 @@ private List> readRecordsFromFile() throws IOException { } @Test(expected=IllegalStateException.class) - public void mustCallWriteBeforeSuccessfulStop() throws IOException { + public void mustCallWriteBeforeSuccessfulStop() throws IOException, SparkException { createWriter(false).stop(true); } @Test - public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException { + public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException, SparkException { createWriter(false).stop(false); } @@ -291,6 +296,69 @@ public void writeWithoutSpilling() throws Exception { assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten()); } + @Test + public void writeChecksumFileWithoutSpill() throws Exception { + IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); + ShuffleChecksumBlockId checksumBlockId = + new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); + File checksumFile = new File(tempDir, + ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + File dataFile = new File(tempDir, "data"); + File indexFile = new File(tempDir, "index"); + when(diskBlockManager.getFile(checksumFile.getName())) + .thenReturn(checksumFile); + when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) + .thenReturn(dataFile); + when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) + .thenReturn(indexFile); + + // In this example, each partition should have exactly one record: + final ArrayList> dataToWrite = new ArrayList<>(); + for (int i = 0; i < NUM_PARTITIONS; i ++) { + dataToWrite.add(new Tuple2<>(i, i)); + } + final UnsafeShuffleWriter writer1 = createWriter(true, blockResolver); + writer1.write(dataToWrite.iterator()); + writer1.stop(true); + assertTrue(checksumFile.exists()); + assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); + compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + } + + @Test + public void writeChecksumFileWithSpill() throws Exception { + IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); + ShuffleChecksumBlockId checksumBlockId = + new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); + File checksumFile = + new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + File dataFile = new File(tempDir, "data"); + File indexFile = new File(tempDir, "index"); + when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile); + 
when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) + .thenReturn(dataFile); + when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) + .thenReturn(indexFile); + + final UnsafeShuffleWriter writer1 = createWriter(true, blockResolver); + writer1.insertRecordIntoSorter(new Tuple2<>(0, 0)); + writer1.forceSorterToSpill(); + writer1.insertRecordIntoSorter(new Tuple2<>(1, 0)); + writer1.insertRecordIntoSorter(new Tuple2<>(2, 0)); + writer1.forceSorterToSpill(); + writer1.insertRecordIntoSorter(new Tuple2<>(0, 1)); + writer1.insertRecordIntoSorter(new Tuple2<>(3, 0)); + writer1.forceSorterToSpill(); + writer1.insertRecordIntoSorter(new Tuple2<>(1, 1)); + writer1.forceSorterToSpill(); + writer1.insertRecordIntoSorter(new Tuple2<>(0, 2)); + writer1.forceSorterToSpill(); + writer1.closeAndWriteOutput(); + assertTrue(checksumFile.exists()); + assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); + compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + } + private void testMergingSpills( final boolean transferToEnabled, String compressionCodecName, @@ -317,7 +385,7 @@ private void testMergingSpills( private void testMergingSpills( boolean transferToEnabled, - boolean encrypted) throws IOException { + boolean encrypted) throws IOException, SparkException { final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { @@ -515,7 +583,7 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { } @Test - public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { + public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, SparkException { final UnsafeShuffleWriter writer = createWriter(false); writer.insertRecordIntoSorter(new Tuple2<>(1, 1)); writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala new file mode 100644 index 0000000000000..a8f2c4088c422 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{DataInputStream, File, FileInputStream} +import java.util.zip.CheckedInputStream + +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper + +trait ShuffleChecksumTestHelper { + + /** + * Ensure that the checksum values are consistent between write and read side. 
+ */ + def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = { + assert(checksum.exists(), "Checksum file doesn't exist") + assert(data.exists(), "Data file doesn't exist") + assert(index.exists(), "Index file doesn't exist") + + var checksumIn: DataInputStream = null + val expectChecksums = Array.ofDim[Long](numPartition) + try { + checksumIn = new DataInputStream(new FileInputStream(checksum)) + (0 until numPartition).foreach(i => expectChecksums(i) = checksumIn.readLong()) + } finally { + if (checksumIn != null) { + checksumIn.close() + } + } + + var dataIn: FileInputStream = null + var indexIn: DataInputStream = null + var checkedIn: CheckedInputStream = null + try { + dataIn = new FileInputStream(data) + indexIn = new DataInputStream(new FileInputStream(index)) + var prevOffset = indexIn.readLong + (0 until numPartition).foreach { i => + val curOffset = indexIn.readLong + val limit = (curOffset - prevOffset).toInt + val bytes = new Array[Byte](limit) + val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName) + checkedIn = new CheckedInputStream( + new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) + checkedIn.read(bytes, 0, limit) + prevOffset = curOffset + // checksum must be consistent at both write and read sides + assert(checkedIn.getChecksum.getValue == expectChecksums(i)) + } + } finally { + if (dataIn != null) { + dataIn.close() + } + if (indexIn != null) { + indexIn.close() + } + if (checkedIn != null) { + checkedIn.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 7fd0bf626fda1..39eef9749eac3 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,13 +33,17 @@ import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils -class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach { +class BypassMergeSortShuffleWriterSuite + extends SparkFunSuite + with BeforeAndAfterEach + with ShuffleChecksumTestHelper { @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _ @@ -76,10 +80,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) - when(blockResolver.writeIndexFileAndCommit( - anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) + when(blockResolver.writeMetadataFileAndCommit( + anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))) 
.thenAnswer { invocationOnMock => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) @@ -236,4 +240,43 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } + + test("write checksum file") { + val blockResolver = new IndexShuffleBlockResolver(conf, blockManager) + val shuffleId = shuffleHandle.shuffleId + val mapId = 0 + val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) + val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) + val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) + val checksumFile = new File(tempDir, + ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)) + val dataFile = new File(tempDir, dataBlockId.name) + val indexFile = new File(tempDir, indexBlockId.name) + reset(diskBlockManager) + when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile) + when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) + when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) + when(diskBlockManager.createTempShuffleBlock()) + .thenAnswer { _ => + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + temporaryFilesCreated += file + (blockId, file) + } + + val numPartition = shuffleHandle.dependency.partitioner.numPartitions + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + mapId, + conf, + taskContext.taskMetrics().shuffleWriteMetrics, + new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver)) + + writer.write(Iterator((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6))) + writer.stop( /* success = */ true) + assert(checksumFile.exists()) + assert(checksumFile.length() === 8 * numPartition) + compareChecksums(numPartition, checksumFile, dataFile, indexFile) + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 5955d442a77fe..49c079cd4fce9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -28,6 +28,7 @@ import org.roaringbitmap.RoaringBitmap import org.scalatest.BeforeAndAfterEach import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,6 +50,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(diskBlockManager.getFile(any[BlockId])).thenAnswer( (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) + when(diskBlockManager.getFile(any[String])).thenAnswer( + (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) when(diskBlockManager.getMergedShuffleFile( any[BlockId], any[Option[Array[String]]])).thenAnswer( (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) @@ -77,7 +80,7 @@ class IndexShuffleBlockResolverSuite extends 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 5955d442a77fe..49c079cd4fce9 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -28,6 +28,7 @@ import org.roaringbitmap.RoaringBitmap
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -49,6 +50,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
      (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
+    when(diskBlockManager.getFile(any[String])).thenAnswer(
+      (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
     when(diskBlockManager.getMergedShuffleFile(
       any[BlockId], any[Option[Array[String]]])).thenAnswer(
       (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
@@ -77,7 +80,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths, Array.empty, dataTmp)
 
     val indexFile = new File(tempDir.getAbsolutePath, idxName)
     val dataFile = resolver.getDataFile(shuffleId, mapId)
@@ -97,7 +100,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out2.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths2, Array.empty, dataTmp2)
     assert(indexFile.length() === (lengths.length + 1) * 8)
     assert(lengths2.toSeq === lengths.toSeq)
 
@@ -136,7 +139,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out3.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths3, Array.empty, dataTmp3)
     assert(indexFile.length() === (lengths3.length + 1) * 8)
     assert(lengths3.toSeq != lengths.toSeq)
     assert(dataFile.exists())
@@ -248,4 +251,19 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
       outIndex.close()
     }
   }
+
+  test("write checksum file") {
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dataTmp = File.createTempFile("shuffle", null, tempDir)
+    val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp)
+    val checksumFile = resolver.getChecksumFile(0, 0)
+    assert(checksumFile.exists())
+    val checksumFileName = checksumFile.toString
+    val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1)
+    assert(checksumAlgo === conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
+    val checksumsFromFile = resolver.getChecksums(checksumFile, 10)
+    assert(checksumsInMemory === checksumsFromFile)
+  }
 }
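The assertion on checksumAlgo above works because the algorithm name travels as the checksum file's extension. Below is a minimal sketch of that lookup, assuming the two java.util.zip implementations (Adler32 and CRC32) are the supported values; the patch's own helper for this is ShuffleChecksumHelper.getChecksumByFileExtension, and the object name here is illustrative.

    import java.util.zip.{Adler32, CRC32, Checksum}

    object ChecksumAlgoLookup {
      // The extension after the last '.' (e.g. a file name ending in ".ADLER32") names the
      // algorithm, so the reader can rebuild the matching Checksum with no extra metadata.
      def checksumFromExtension(fileName: String): Checksum =
        fileName.substring(fileName.lastIndexOf(".") + 1) match {
          case "ADLER32" => new Adler32()
          case "CRC32" => new CRC32()
          case other => throw new IllegalArgumentException(s"Unsupported checksum algorithm: $other")
        }
    }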
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 4c679fd874c9b..e3457367d9baf 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -20,19 +20,25 @@ package org.apache.spark.shuffle.sort
 import org.mockito.{Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.Mockito._
+import org.scalatest.PrivateMethodTester
 import org.scalatest.matchers.must.Matchers
 
-import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite}
+import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.MemoryTestingUtils
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
 import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.ExternalSorter
 
-
-class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers {
+class SortShuffleWriterSuite
+  extends SparkFunSuite
+  with SharedSparkContext
+  with Matchers
+  with PrivateMethodTester
+  with ShuffleChecksumTestHelper {
 
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
 
@@ -44,13 +50,14 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
   private val serializer = new JavaSerializer(conf)
   private var shuffleExecutorComponents: ShuffleExecutorComponents = _
 
+  private val partitioner = new Partitioner() {
+    def numPartitions = numMaps
+    def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
+  }
+
   override def beforeEach(): Unit = {
     super.beforeEach()
     MockitoAnnotations.openMocks(this).close()
-    val partitioner = new Partitioner() {
-      def numPartitions = numMaps
-      def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
-    }
     shuffleHandle = {
       val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
       when(dependency.partitioner).thenReturn(partitioner)
@@ -103,4 +110,68 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
     assert(dataFile.length() === writeMetrics.bytesWritten)
     assert(records.size === writeMetrics.recordsWritten)
   }
+
+  Seq((true, false, false),
+    (true, true, false),
+    (true, false, true),
+    (true, true, true),
+    (false, false, false),
+    (false, true, false),
+    (false, false, true),
+    (false, true, true)).foreach { case (doSpill, doAgg, doOrder) =>
+    test(s"write checksum file (spill=$doSpill, aggregator=$doAgg, order=$doOrder)") {
+      val aggregator = if (doAgg) {
+        Some(Aggregator[Int, Int, Int](
+          v => v,
+          (c, v) => c + v,
+          (c1, c2) => c1 + c2))
+      } else None
+      val order = if (doOrder) {
+        Some(new Ordering[Int] {
+          override def compare(x: Int, y: Int): Int = x - y
+        })
+      } else None
+
+      val shuffleHandle = {
+        val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+        when(dependency.partitioner).thenReturn(partitioner)
+        when(dependency.serializer).thenReturn(serializer)
+        when(dependency.aggregator).thenReturn(aggregator)
+        when(dependency.keyOrdering).thenReturn(order)
+        new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency)
+      }
+
+      // FIXME: this can affect other tests (if any) after this set of tests
+      // since `sc` is global.
+      sc.stop()
+      conf.set("spark.shuffle.spill.numElementsForceSpillThreshold",
+        if (doSpill) "0" else Int.MaxValue.toString)
+      conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+      val localSC = new SparkContext("local[4]", "test", conf)
+      val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+      val context = MemoryTestingUtils.fakeTaskContext(localSC.env)
+      val records = List[(Int, Int)](
+        (0, 1), (1, 2), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5), (3, 5), (4, 6))
+      val numPartition = shuffleHandle.dependency.partitioner.numPartitions
+      val writer = new SortShuffleWriter[Int, Int, Int](
+        shuffleHandle,
+        mapId = 0,
+        context,
+        new LocalDiskShuffleExecutorComponents(
+          conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
+      writer.write(records.toIterator)
+      val sorterMethod = PrivateMethod[ExternalSorter[_, _, _]](Symbol("sorter"))
+      val sorter = writer.invokePrivate(sorterMethod())
+      val expectSpillSize = if (doSpill) records.size else 0
+      assert(sorter.numSpills === expectSpillSize)
+      writer.stop(success = true)
+      val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0)
+      assert(checksumFile.exists())
+      assert(checksumFile.length() === 8 * numPartition)
+      val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0)
+      val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0)
+      compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+      localSC.stop()
+    }
+  }
 }
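Two test-only techniques in the suite above are worth calling out: spilling is forced (or suppressed) through spark.shuffle.spill.numElementsForceSpillThreshold, and the writer's private sorter is reached with ScalaTest's PrivateMethodTester so numSpills can be asserted before stop() is called. A self-contained sketch of the PrivateMethodTester pattern follows; the Counter class is purely illustrative.

    import org.scalatest.PrivateMethodTester

    object PrivateAccessSketch extends PrivateMethodTester {
      class Counter {
        private var count = 0
        private def bump(): Int = { count += 1; count }
      }

      def main(args: Array[String]): Unit = {
        // PrivateMethod[R](Symbol("name")) names a private member; invokePrivate calls it
        // reflectively on the target instance and returns R.
        val bumpMethod = PrivateMethod[Int](Symbol("bump"))
        val counter = new Counter
        assert((counter invokePrivate bumpMethod()) == 1)
        assert((counter invokePrivate bumpMethod()) == 2)
      }
    }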
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index ef5c615bf7591..35d9b4ab1f766 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -74,11 +74,11 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       .set("spark.app.id", "example.spark.app")
       .set("spark.shuffle.unsafe.file.output.buffer", "16k")
     when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
-    when(blockResolver.writeIndexFileAndCommit(
-      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+    when(blockResolver.writeMetadataFileAndCommit(
+      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
         partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
-        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        val tmp: File = invocationOnMock.getArguments()(4).asInstanceOf[File]
         if (tmp != null) {
           mergedOutputFile.delete()
           tmp.renameTo(mergedOutputFile)
@@ -136,7 +136,8 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
   }
 
   private def verifyWrittenRecords(): Unit = {
-    val committedLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+    val committedLengths =
+      mapOutputWriter.commitAllPartitions(Array.empty[Long]).getPartitionLengths
     assert(partitionSizesInMergedFile === partitionLengths)
     assert(committedLengths === partitionLengths)
     assert(mergedOutputFile.length() === partitionLengths.sum)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index a6de64b6c68a0..7bec961218759 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -500,15 +500,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       intercept[SparkException] {
         data.reduceByKey(_ + _).count()
       }
-      // After the shuffle, there should be only 2 files on disk: the output of task 1 and
-      // its index. All other files (map 2's output and intermediate merge files) should
-      // have been deleted.
-      assert(diskBlockManager.getAllFiles().length === 2)
+      // After the shuffle, there should be only 3 files on disk: the output of task 1 and
+      // its index and checksum. All other files (map 2's output and intermediate merge files)
+      // should have been deleted.
+      assert(diskBlockManager.getAllFiles().length === 3)
     } else {
       assert(data.reduceByKey(_ + _).count() === size)
-      // After the shuffle, there should be only 4 files on disk: the output of both tasks
-      // and their indices. All intermediate merge files should have been deleted.
-      assert(diskBlockManager.getAllFiles().length === 4)
+      // After the shuffle, there should be only 6 files on disk: the output of both tasks
+      // and their indices/checksums. All intermediate merge files should have been deleted.
+      assert(diskBlockManager.getAllFiles().length === 6)
     }
   }
 }
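The updated counts above (2 to 3 files for one surviving map task, 4 to 6 for two tasks) follow from each map output now consisting of a data file, an index file, and a checksum file. A rough accounting sketch is below; the block-id naming uses the usual shuffle_<shuffleId>_<mapId>_<reduceId> pattern, and the ".checksum" suffix is an assumption for illustration rather than something shown in this hunk.

    object ShuffleFileAccounting {
      // Per-map-task files with checksums enabled; one surviving task -> 3 files,
      // two tasks -> 6 files, matching the adjusted asserts above.
      def expectedFiles(shuffleId: Int, mapIds: Seq[Long]): Seq[String] =
        mapIds.flatMap { mapId =>
          Seq(
            s"shuffle_${shuffleId}_${mapId}_0.data",     // partitioned map output
            s"shuffle_${shuffleId}_${mapId}_0.index",    // partition offsets
            s"shuffle_${shuffleId}_${mapId}_0.checksum") // per-partition checksums (new)
        }
    }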
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 032dd19fd594b..04694724ca381 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -95,7 +95,15 @@ object MimaExcludes {
     ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.WritablePartitionedIterator"),
 
     // [SPARK-35757][CORE] Add bitwise AND operation and functionality for intersecting bloom filters
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace"),
+
+    // [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskSingleSpillMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions")
   )
 
   def excludes(version: String) = version match {
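The MiMa filters above are needed because two public shuffle-plugin interfaces change shape: ShuffleMapOutputWriter.commitAllPartitions now takes the per-partition checksums, and SingleSpillShuffleMapOutputWriter.transferMapSpillFile gains a checksums array as well. The following is a minimal sketch of what an external plugin now has to override, with stub bodies only; it is not a working writer, and the class name is illustrative.

    import org.apache.spark.shuffle.api.{MapOutputCommitMessage, ShuffleMapOutputWriter, ShufflePartitionWriter}

    class IllustrativeMapOutputWriter extends ShuffleMapOutputWriter {

      override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter =
        throw new UnsupportedOperationException("illustrative stub only")

      // Updated signature: Spark hands over the checksums it computed while writing
      // (an empty array when shuffle checksums are disabled), so the plugin can store
      // them and later help diagnose fetch-side corruption.
      override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage =
        throw new UnsupportedOperationException("illustrative stub only")

      override def abort(error: Throwable): Unit = ()
    }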