diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 0167002ceedb8..2237ec052cac2 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -59,6 +59,10 @@ public interface ShuffleMapOutputWriter {
* available to downstream reduce tasks. If this method throws any exception, this module's
* {@link #abort(Throwable)} method will be invoked before propagating the exception.
*
+ * Shuffle extensions that care about the cause of shuffle data corruption should store
+ * the checksums properly. When corruption happens, Spark will provide the checksum
+ * of the fetched partition to the shuffle extension to help diagnose the cause of the corruption.
+ *
* This can also close any resources and clean up temporary state if necessary.
*
* The returned commit message is a structure with two components:
@@ -68,8 +72,11 @@ public interface ShuffleMapOutputWriter {
* for that partition id.
*
* 2) An optional metadata blob that can be used by shuffle readers.
+ *
+ * @param checksums The checksum value for each partition (the checksum index corresponds to
+ *                  the partition id) if the shuffle checksum is enabled; otherwise, empty.
*/
- MapOutputCommitMessage commitAllPartitions() throws IOException;
+ MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException;
/**
* Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
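Example (illustrative, not part of this patch): a caller-side sketch of the new
commitAllPartitions(long[]) contract, showing how per-partition checksum values are
collected from running java.util.zip.Checksum calculators and handed to the writer.
The class and method names below are hypothetical.

import java.io.IOException;
import java.util.zip.Checksum;

import org.apache.spark.shuffle.api.MapOutputCommitMessage;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;

class CommitWithChecksumsExample {
  // Collect the running checksum of each partition and commit the map output.
  static long[] commit(ShuffleMapOutputWriter writer, Checksum[] partitionChecksums)
      throws IOException {
    long[] checksums = new long[partitionChecksums.length];
    for (int i = 0; i < partitionChecksums.length; i++) {
      checksums[i] = partitionChecksums[i].getValue();
    }
    // An empty array means the shuffle checksum feature is disabled.
    MapOutputCommitMessage commitMessage = writer.commitAllPartitions(checksums);
    return commitMessage.getPartitionLengths();
  }
}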
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
index 928875156a70f..143cc6c871e5f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -49,7 +49,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link OutputStream#close()} does not close the resource, since it will be reused across
* partition writes. The underlying resources should be cleaned up in
- * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
OutputStream openStream() throws IOException;
@@ -68,7 +68,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
* will be reused across partition writes. The underlying resources should be cleaned up in
- * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*
* This method is primarily for advanced optimizations where bytes can be copied from the input
@@ -79,7 +79,7 @@ public interface ShufflePartitionWriter {
* Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
* underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
* that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
- * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])}, or
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
index cad8dcfda52bc..ba3d5a603e052 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter {
/**
* Transfer a file that contains the bytes of all the partitions written by this map task.
*/
- void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+ void transferMapSpillFile(
+ File mapOutputFile,
+ long[] partitionLengths,
+ long[] checksums) throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
new file mode 100644
index 0000000000000..a368836d2bb1d
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum;
+
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.Checksum;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.storage.ShuffleChecksumBlockId;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+
+ /** Used when the checksum is disabled for shuffle. */
+ private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+ public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+ public static boolean isShuffleChecksumEnabled(SparkConf conf) {
+ return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED());
+ }
+
+ public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf)
+ throws SparkException {
+ if (!isShuffleChecksumEnabled(conf)) {
+ return EMPTY_CHECKSUM;
+ }
+
+ String checksumAlgo = shuffleChecksumAlgorithm(conf);
+ return getChecksumByAlgorithm(numPartitions, checksumAlgo);
+ }
+
+ private static Checksum[] getChecksumByAlgorithm(int num, String algorithm)
+ throws SparkException {
+ Checksum[] checksums;
+ switch (algorithm) {
+ case "ADLER32":
+ checksums = new Adler32[num];
+ for (int i = 0; i < num; i ++) {
+ checksums[i] = new Adler32();
+ }
+ return checksums;
+
+ case "CRC32":
+ checksums = new CRC32[num];
+ for (int i = 0; i < num; i ++) {
+ checksums[i] = new CRC32();
+ }
+ return checksums;
+
+ default:
+ throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm);
+ }
+ }
+
+ public static long[] getChecksumValues(Checksum[] partitionChecksums) {
+ int numPartitions = partitionChecksums.length;
+ long[] checksumValues = new long[numPartitions];
+ for (int i = 0; i < numPartitions; i ++) {
+ checksumValues[i] = partitionChecksums[i].getValue();
+ }
+ return checksumValues;
+ }
+
+ public static String shuffleChecksumAlgorithm(SparkConf conf) {
+ return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+ }
+
+ public static Checksum getChecksumByFileExtension(String fileName) throws SparkException {
+ int index = fileName.lastIndexOf(".");
+ String algorithm = fileName.substring(index + 1);
+ return getChecksumByAlgorithm(1, algorithm)[0];
+ }
+
+ public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) {
+ // append the shuffle checksum algorithm as the file extension
+ return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf));
+ }
+}
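Usage sketch for the helper above (illustrative only; the write path that feeds each
partition's bytes into its calculator is omitted):

import java.util.zip.Checksum;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;

class ShuffleChecksumHelperExample {
  static long[] partitionChecksumValues(int numPartitions, SparkConf conf)
      throws SparkException {
    // Returns an empty array when spark.shuffle.checksum.enabled is false.
    Checksum[] calculators =
        ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
    // While writing, each partition's bytes would update calculators[partitionId].
    return ShuffleChecksumHelper.getChecksumValues(calculators);
  }
}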
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 3dbee1b13d287..322224053df09 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -23,6 +23,7 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Optional;
+import java.util.zip.Checksum;
import javax.annotation.Nullable;
import scala.None$;
@@ -38,6 +39,7 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -49,6 +51,7 @@
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -93,6 +96,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
+ /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
+ private final Checksum[] partitionChecksums;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -107,7 +112,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
long mapId,
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics,
- ShuffleExecutorComponents shuffleExecutorComponents) {
+ ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -120,6 +125,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
+ this.partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
}
@Override
@@ -129,7 +136,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
- partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
+ partitionLengths = mapOutputWriter.commitAllPartitions(
+ ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
return;
@@ -143,8 +151,12 @@ public void write(Iterator> records) throws IOException {
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
- partitionWriters[i] =
- blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ DiskBlockObjectWriter writer =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ if (partitionChecksums.length > 0) {
+ writer.setChecksum(partitionChecksums[i]);
+ }
+ partitionWriters[i] = writer;
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -218,7 +230,9 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
}
partitionWriters = null;
}
- return mapOutputWriter.commitAllPartitions().getPartitionLengths();
+ return mapOutputWriter.commitAllPartitions(
+ ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+ ).getPartitionLengths();
}
private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 833744f4777ce..0307027c6f264 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -21,7 +21,9 @@
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import java.util.zip.Checksum;
+import org.apache.spark.SparkException;
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
@@ -39,6 +41,7 @@
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -107,6 +110,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
@Nullable private MemoryBlock currentPage = null;
private long pageCursor = -1;
+ // Checksum calculator for each partition. Empty when shuffle checksum disabled.
+ private final Checksum[] partitionChecksums;
+
ShuffleExternalSorter(
TaskMemoryManager memoryManager,
BlockManager blockManager,
@@ -114,7 +120,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
int initialSize,
int numPartitions,
SparkConf conf,
- ShuffleWriteMetricsReporter writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics) throws SparkException {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -133,6 +139,12 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.peakMemoryUsedBytes = getMemoryUsage();
this.diskWriteBufferSize =
(int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
+ this.partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
+ }
+
+ public long[] getChecksums() {
+ return ShuffleChecksumHelper.getChecksumValues(partitionChecksums);
}
/**
@@ -204,6 +216,9 @@ private void writeSortedFile(boolean isLastFile) {
spillInfo.partitionLengths[currentPartition] = fileSegment.length();
}
currentPartition = partition;
+ if (partitionChecksums.length > 0) {
+ writer.setChecksum(partitionChecksums[currentPartition]);
+ }
}
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e8f94ba8ffeee..2659b172bf68c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -57,6 +57,7 @@
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
@@ -115,7 +116,7 @@ public UnsafeShuffleWriter(
TaskContext taskContext,
SparkConf sparkConf,
ShuffleWriteMetricsReporter writeMetrics,
- ShuffleExecutorComponents shuffleExecutorComponents) {
+ ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -198,7 +199,7 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
}
}
- private void open() {
+ private void open() throws SparkException {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
@@ -219,10 +220,10 @@ void closeAndWriteOutput() throws IOException {
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
- sorter = null;
try {
partitionLengths = mergeSpills(spills);
} finally {
+ sorter = null;
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
@@ -267,7 +268,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
if (spills.length == 0) {
final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
- return mapWriter.commitAllPartitions().getPartitionLengths();
+ return mapWriter.commitAllPartitions(
+ ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
} else if (spills.length == 1) {
Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
@@ -277,7 +279,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
partitionLengths = spills[0].partitionLengths;
logger.debug("Merge shuffle spills for mapId {} with length {}", mapId,
partitionLengths.length);
- maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+ maybeSingleFileWriter.get()
+ .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums());
} else {
partitionLengths = mergeSpillsUsingStandardWriter(spills);
}
@@ -330,7 +333,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
+ partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
} catch (Exception e) {
try {
mapWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index 0b286264be43d..6c5025d1822f8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -98,7 +98,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
}
@Override
- public MapOutputCommitMessage commitAllPartitions() throws IOException {
+ public MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException {
// Check the position after transferTo loop to see if it is in the right position and raise a
// exception if it is incorrect. The position will not be increased to the expected length
// after calling transferTo in kernel version 2.6.32. This issue is described at
@@ -115,7 +115,8 @@ public MapOutputCommitMessage commitAllPartitions() throws IOException {
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
log.debug("Writing shuffle index file for mapId {} with length {}", mapId,
partitionLengths.length);
- blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+ blockResolver
+ .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, resolvedTmp);
return MapOutputCommitMessage.of(partitionLengths);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
index c8b41992a8919..6a994b49d3a29 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -44,12 +44,14 @@ public LocalDiskSingleSpillMapOutputWriter(
@Override
public void transferMapSpillFile(
File mapSpillFile,
- long[] partitionLengths) throws IOException {
+ long[] partitionLengths,
+ long[] checksums) throws IOException {
// The map spill file already has the proper format, and it contains all of the partition data.
// So just transfer it directly to the destination without any merging.
File outputFile = blockResolver.getDataFile(shuffleId, mapId);
File tempFile = Utils.tempFileWith(outputFile);
Files.move(mapSpillFile.toPath(), tempFile.toPath());
- blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+ blockResolver
+ .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, tempFile);
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2a8354a528806..1e9e7c9bece77 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1370,6 +1370,25 @@ package object config {
s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.")
.createWithDefault(4096)
+ private[spark] val SHUFFLE_CHECKSUM_ENABLED =
+ ConfigBuilder("spark.shuffle.checksum.enabled")
+ .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " +
+ "its best to tell if shuffle data corruption is caused by network or disk or others.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ private[spark] val SHUFFLE_CHECKSUM_ALGORITHM =
+ ConfigBuilder("spark.shuffle.checksum.algorithm")
+ .doc("The algorithm used to calculate the checksum. Currently, it only supports" +
+ " built-in algorithms of JDK.")
+ .version("3.3.0")
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " +
+ "should be either Adler32 or CRC32.")
+ .createWithDefault("ADLER32")
+
private[spark] val SHUFFLE_COMPRESS =
ConfigBuilder("spark.shuffle.compress")
.doc("Whether to compress shuffle output. Compression will use " +
diff --git a/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
new file mode 100644
index 0000000000000..754b4a87720a9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.io
+
+import java.io.OutputStream
+import java.util.zip.Checksum
+
+/**
+ * A variant of [[java.util.zip.CheckedOutputStream]] which can
+ * change the checksum calculator at runtime.
+ */
+class MutableCheckedOutputStream(out: OutputStream) extends OutputStream {
+ private var checksum: Checksum = _
+
+ def setChecksum(c: Checksum): Unit = {
+ this.checksum = c
+ }
+
+ override def write(b: Int): Unit = {
+ assert(checksum != null, "Checksum is not set.")
+ checksum.update(b)
+ out.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ assert(checksum != null, "Checksum is not set.")
+ checksum.update(b, off, len)
+ out.write(b, off, len)
+ }
+
+ override def flush(): Unit = out.flush()
+
+ override def close(): Unit = out.close()
+}
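A minimal sketch of the stream above: the checksum calculator can be swapped between writes,
which is what allows a single output stream to compute one checksum per partition. The byte
arrays and class name are illustrative.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.Adler32;

import org.apache.spark.io.MutableCheckedOutputStream;

class MutableCheckedOutputStreamExample {
  static long[] checksumTwoPartitions(byte[] partition0, byte[] partition1) throws IOException {
    ByteArrayOutputStream target = new ByteArrayOutputStream();
    MutableCheckedOutputStream out = new MutableCheckedOutputStream(target);
    Adler32 checksum0 = new Adler32();
    Adler32 checksum1 = new Adler32();
    out.setChecksum(checksum0);
    out.write(partition0);      // bytes of partition 0 update checksum0
    out.setChecksum(checksum1); // swap the calculator without reopening the stream
    out.write(partition1);      // bytes of partition 1 update checksum1
    out.close();
    return new long[] {checksum0.getValue(), checksum1.getValue()};
  }
}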
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5b1f93359c6ff..9c50569c78924 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -22,8 +22,10 @@ import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.file.Files
-import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.internal.Logging
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.client.StreamCallbackWithID
@@ -31,6 +33,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -56,6 +59,8 @@ private[spark] class IndexShuffleBlockResolver(
private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+ private val remoteShuffleMaxDisk: Option[Long] =
+ conf.get(config.STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE)
def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)
@@ -72,6 +77,13 @@ private[spark] class IndexShuffleBlockResolver(
}
}
+ private def getShuffleBytesStored(): Long = {
+ val shuffleFiles: Seq[File] = getStoredShuffles().map {
+ si => getDataFile(si.shuffleId, si.mapId)
+ }
+ shuffleFiles.map(_.length()).sum
+ }
+
/**
* Get the shuffle data file.
*
@@ -133,17 +145,18 @@ private[spark] class IndexShuffleBlockResolver(
*/
def removeDataByMap(shuffleId: Int, mapId: Long): Unit = {
var file = getDataFile(shuffleId, mapId)
- if (file.exists()) {
- if (!file.delete()) {
- logWarning(s"Error deleting data ${file.getPath()}")
- }
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting data ${file.getPath()}")
}
file = getIndexFile(shuffleId, mapId)
- if (file.exists()) {
- if (!file.delete()) {
- logWarning(s"Error deleting index ${file.getPath()}")
- }
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting index ${file.getPath()}")
+ }
+
+ file = getChecksumFile(shuffleId, mapId)
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting checksum ${file.getPath()}")
}
}
@@ -200,6 +213,13 @@ private[spark] class IndexShuffleBlockResolver(
*/
override def putShuffleBlockAsStream(blockId: BlockId, serializerManager: SerializerManager):
StreamCallbackWithID = {
+ // Throw an exception if we have exceeded maximum shuffle files stored
+ remoteShuffleMaxDisk.foreach { maxBytes =>
+ val bytesUsed = getShuffleBytesStored()
+ if (maxBytes < bytesUsed) {
+ throw new SparkException(s"Not storing remote shuffles $bytesUsed exceeds $maxBytes")
+ }
+ }
val file = blockId match {
case ShuffleIndexBlockId(shuffleId, mapId, _) =>
getIndexFile(shuffleId, mapId)
@@ -273,7 +293,7 @@ private[spark] class IndexShuffleBlockResolver(
throw new FileNotFoundException("Index file is deleted already.")
}
if (dataFile.exists()) {
- List((indexBlockId, indexBlockData), (dataBlockId, dataBlockData))
+ List((dataBlockId, dataBlockData), (indexBlockId, indexBlockData))
} else {
List((indexBlockId, indexBlockData))
}
@@ -287,22 +307,41 @@ private[spark] class IndexShuffleBlockResolver(
/**
- * Write an index file with the offsets of each block, plus a final offset at the end for the
- * end of the output file. This will be used by getBlockData to figure out where each block
- * begins and ends.
+ * Commit the data and metadata files as an atomic operation: either use the existing ones or
+ * replace them with new ones. Note that the metadata parameters (`lengths`, `checksums`)
+ * will be updated to match the existing files if the existing ones are used.
*
- * It will commit the data and index file as an atomic operation, use the existing ones, or
- * replace them with new ones.
+ * There are two kinds of metadata files:
*
- * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
+ * - index file
+ * An index file contains the offsets of each block, plus a final offset at the end
+ * for the end of the output file. It will be used by [[getBlockData]] to figure out
+ * where each block begins and ends.
+ *
+ * - checksum file (optional)
+ * A checksum file contains the checksum of each block. It will be used to diagnose
+ * the cause when a block is corrupted. Note that an empty `checksums` array indicates
+ * that the checksum is disabled.
*/
- def writeIndexFileAndCommit(
+ def writeMetadataFileAndCommit(
shuffleId: Int,
mapId: Long,
lengths: Array[Long],
+ checksums: Array[Long],
dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
val indexTmp = Utils.tempFileWith(indexFile)
+
+ val checksumEnabled = checksums.nonEmpty
+ val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) {
+ assert(lengths.length == checksums.length,
+ "The size of partition lengths and checksums should be equal")
+ val checksumFile = getChecksumFile(shuffleId, mapId)
+ (Some(checksumFile), Some(Utils.tempFileWith(checksumFile)))
+ } else {
+ (None, None)
+ }
+
try {
val dataFile = getDataFile(shuffleId, mapId)
// There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
@@ -313,37 +352,47 @@ private[spark] class IndexShuffleBlockResolver(
// Another attempt for the same task has already written our map outputs successfully,
// so just use the existing partition lengths and delete our temporary map outputs.
System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+ if (checksumEnabled) {
+ val existingChecksums = getChecksums(checksumFileOpt.get, checksums.length)
+ if (existingChecksums != null) {
+ System.arraycopy(existingChecksums, 0, checksums, 0, lengths.length)
+ } else {
+ // It's possible that the previous task attempt succeeded writing the
+ // index file and data file but failed to write the checksum file. In
+ // this case, the current task attempt could write the missing checksum
+ // file by itself.
+ writeMetadataFile(checksums, checksumTmpOpt.get, checksumFileOpt.get, false)
+ }
+ }
if (dataTmp != null && dataTmp.exists()) {
dataTmp.delete()
}
} else {
// This is the first successful attempt in writing the map outputs for this task,
// so override any existing index and data files with the ones we wrote.
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
- Utils.tryWithSafeFinally {
- // We take in lengths of each block, need to convert it to offsets.
- var offset = 0L
- out.writeLong(offset)
- for (length <- lengths) {
- offset += length
- out.writeLong(offset)
- }
- } {
- out.close()
- }
- if (indexFile.exists()) {
- indexFile.delete()
- }
+ val offsets = lengths.scanLeft(0L)(_ + _)
+ writeMetadataFile(offsets, indexTmp, indexFile, true)
+
if (dataFile.exists()) {
dataFile.delete()
}
- if (!indexTmp.renameTo(indexFile)) {
- throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
- }
if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
}
+
+ // write the checksum file
+ checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) =>
+ try {
+ writeMetadataFile(checksums, checksumTmp, checksumFile, false)
+ } catch {
+ case e: Exception =>
+ // It's not worthwhile to fail here after the index and data files have already
+ // been stored successfully, since the checksum is only a best-effort for
+ // the corner error case.
+ logError("Failed to write checksum file", e)
+ }
+ }
}
}
} finally {
@@ -351,6 +400,63 @@ private[spark] class IndexShuffleBlockResolver(
if (indexTmp.exists() && !indexTmp.delete()) {
logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
}
+ checksumTmpOpt.foreach { checksumTmp =>
+ if (checksumTmp.exists()) {
+ try {
+ if (!checksumTmp.delete()) {
+ logError(s"Failed to delete temporary checksum file " +
+ s"at ${checksumTmp.getAbsolutePath}")
+ }
+ } catch {
+ case e: Exception =>
+ // Unlike index deletion, we won't propagate the error for the checksum file since
+ // checksum is only a best-effort.
+ logError(s"Failed to delete temporary checksum file " +
+ s"at ${checksumTmp.getAbsolutePath}", e)
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Write the metadata file (index or checksum). Metadata values are first written into
+ * the tmp file, which is then renamed to the target file at the end to avoid dirty
+ * writes.
+ * @param metaValues The metadata values
+ * @param tmpFile The temp file
+ * @param targetFile The target file
+ * @param propagateError Whether to propagate errors from file operations. Unlike the index
+ *                       file, the checksum is only best-effort, so we won't fail the whole
+ *                       task due to a checksum error.
+ */
+ private def writeMetadataFile(
+ metaValues: Array[Long],
+ tmpFile: File,
+ targetFile: File,
+ propagateError: Boolean): Unit = {
+ val out = new DataOutputStream(
+ new BufferedOutputStream(
+ new FileOutputStream(tmpFile)
+ )
+ )
+ Utils.tryWithSafeFinally {
+ metaValues.foreach(out.writeLong)
+ } {
+ out.close()
+ }
+
+ if (targetFile.exists()) {
+ targetFile.delete()
+ }
+
+ if (!tmpFile.renameTo(targetFile)) {
+ val errorMsg = s"fail to rename file $tmpFile to $targetFile"
+ if (propagateError) {
+ throw new IOException(errorMsg)
+ } else {
+ logWarning(errorMsg)
+ }
}
}
@@ -398,6 +504,45 @@ private[spark] class IndexShuffleBlockResolver(
new MergedBlockMeta(numChunks, chunkBitMaps)
}
+ private[shuffle] def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = {
+ if (!checksumFile.exists()) return null
+ val checksums = new ArrayBuffer[Long]
+ // Read the checksums of blocks
+ var in: DataInputStream = null
+ try {
+ in = new DataInputStream(new NioBufferedFileInputStream(checksumFile))
+ while (checksums.size < blockNum) {
+ checksums += in.readLong()
+ }
+ } catch {
+ case _: IOException | _: EOFException =>
+ return null
+ } finally {
+ if (in != null) in.close()
+ }
+
+ checksums.toArray
+ }
+
+ /**
+ * Get the shuffle checksum file.
+ *
+ * When the dirs parameter is None then use the disk manager's local directories. Otherwise,
+ * read from the specified directories.
+ */
+ def getChecksumFile(
+ shuffleId: Int,
+ mapId: Long,
+ dirs: Option[Array[String]] = None): File = {
+ val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+ val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf)
+ dirs
+ .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName))
+ .getOrElse {
+ blockManager.diskBlockManager.getFile(fileName)
+ }
+ }
+
override def getBlockData(
blockId: BlockId,
dirs: Option[Array[String]]): ManagedBuffer = {
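The checksum file written by writeMetadataFile is simply one long value per reduce partition,
in partition order. An illustrative standalone reader (getChecksums above is the in-tree
equivalent, with extra handling for missing or truncated files):

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;

class ChecksumFileReaderExample {
  static long[] readChecksums(File checksumFile, int numPartitions) throws IOException {
    long[] checksums = new long[numPartitions];
    try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
      for (int i = 0; i < numPartitions; i++) {
        checksums[i] = in.readLong();
      }
    }
    return checksums;
  }
}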
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index e0affb858c359..9843ae18bba77 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -18,7 +18,9 @@
package org.apache.spark.shuffle
import java.io.{Closeable, IOException, OutputStream}
+import java.util.zip.Checksum
+import org.apache.spark.io.MutableCheckedOutputStream
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream}
@@ -34,7 +36,8 @@ private[spark] class ShufflePartitionPairsWriter(
serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
blockId: BlockId,
- writeMetrics: ShuffleWriteMetricsReporter)
+ writeMetrics: ShuffleWriteMetricsReporter,
+ checksum: Checksum)
extends PairsWriter with Closeable {
private var isClosed = false
@@ -44,6 +47,9 @@ private[spark] class ShufflePartitionPairsWriter(
private var objOut: SerializationStream = _
private var numRecordsWritten = 0
private var curNumBytesWritten = 0L
+ // This is only initialized when checksum != null,
+ // which indicates that the shuffle checksum is enabled.
+ private var checksumOutputStream: MutableCheckedOutputStream = _
override def write(key: Any, value: Any): Unit = {
if (isClosed) {
@@ -61,7 +67,12 @@ private[spark] class ShufflePartitionPairsWriter(
try {
partitionStream = partitionWriter.openStream
timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
- wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
+ if (checksum != null) {
+ checksumOutputStream = new MutableCheckedOutputStream(timeTrackingStream)
+ checksumOutputStream.setChecksum(checksum)
+ }
+ wrappedStream = serializerManager.wrapStream(blockId,
+ if (checksumOutputStream != null) checksumOutputStream else timeTrackingStream)
objOut = serializerInstance.serializeStream(wrappedStream)
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index adbe6eca5800b..3cbf30160efb8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -68,7 +68,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+ partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index dc70a9af7e9c3..db5862dec2fbe 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -92,6 +92,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
}
+@Since("3.3.0")
+@DeveloperApi
+case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+ override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum"
+}
+
@Since("3.2.0")
@DeveloperApi
case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
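For reference, the new block id produces names like the one below; ShuffleChecksumHelper
appends the configured algorithm (e.g. ".ADLER32") as a file extension when resolving the
on-disk file name. The ids used here are illustrative.

import org.apache.spark.storage.ShuffleChecksumBlockId;

class ShuffleChecksumBlockIdExample {
  public static void main(String[] args) {
    // reduceId is NOOP_REDUCE_ID (0) since the file covers the whole map output.
    ShuffleChecksumBlockId blockId = new ShuffleChecksumBlockId(0, 5L, 0);
    System.out.println(blockId.name()); // prints "shuffle_0_5_0.checksum"
  }
}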
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index e55c09274cd9a..f5d8c0219dc83 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -19,8 +19,10 @@ package org.apache.spark.storage
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
import java.nio.channels.{ClosedByInterruptException, FileChannel}
+import java.util.zip.Checksum
import org.apache.spark.internal.Logging
+import org.apache.spark.io.MutableCheckedOutputStream
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
@@ -77,6 +79,11 @@ private[spark] class DiskBlockObjectWriter(
private var streamOpen = false
private var hasBeenClosed = false
+ // checksum related
+ private var checksumEnabled = false
+ private var checksumOutputStream: MutableCheckedOutputStream = _
+ private var checksum: Checksum = _
+
/**
* Cursors used to represent positions in the file.
*
@@ -101,12 +108,30 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
+ /**
+ * Set the checksum that the checksumOutputStream should use
+ */
+ def setChecksum(checksum: Checksum): Unit = {
+ if (checksumOutputStream == null) {
+ this.checksumEnabled = true
+ this.checksum = checksum
+ } else {
+ checksumOutputStream.setChecksum(checksum)
+ }
+ }
+
private def initialize(): Unit = {
fos = new FileOutputStream(file, true)
channel = fos.getChannel()
ts = new TimeTrackingOutputStream(writeMetrics, fos)
+ if (checksumEnabled) {
+ assert(this.checksum != null, "Checksum is not set")
+ checksumOutputStream = new MutableCheckedOutputStream(ts)
+ checksumOutputStream.setChecksum(checksum)
+ }
class ManualCloseBufferedOutputStream
- extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+ extends BufferedOutputStream(if (checksumEnabled) checksumOutputStream else ts, bufferSize)
+ with ManualCloseOutputStream
mcs = new ManualCloseBufferedOutputStream
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index dc39170ecf382..3b54cbbc51cf6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
import org.apache.spark.shuffle.ShufflePartitionPairsWriter
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
import org.apache.spark.util.{Utils => TryUtils}
@@ -141,6 +142,11 @@ private[spark] class ExternalSorter[K, V, C](
private val forceSpillFiles = new ArrayBuffer[SpilledFile]
@volatile private var readingIterator: SpillableIterator = null
+ private val partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)
+
+ def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -746,7 +752,8 @@ private[spark] class ExternalSorter[K, V, C](
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics)
+ context.taskMetrics().shuffleWriteMetrics,
+ if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(partitionPairsWriter)
}
@@ -770,7 +777,8 @@ private[spark] class ExternalSorter[K, V, C](
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics)
+ context.taskMetrics().shuffleWriteMetrics,
+ if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
if (elements.hasNext) {
for (elem <- elements) {
partitionPairsWriter.write(elem._1, elem._2)
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index df1d306e628a9..ecb14a858009e 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -22,11 +22,11 @@
import java.nio.file.Files;
import java.util.*;
+import org.apache.spark.*;
+import org.apache.spark.shuffle.ShuffleChecksumTestHelper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.mockito.stubbing.Answer;
-import scala.Option;
-import scala.Product2;
-import scala.Tuple2;
-import scala.Tuple2$;
+import scala.*;
import scala.collection.Iterator;
import com.google.common.collect.HashMultiset;
@@ -36,10 +36,6 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
@@ -65,7 +61,7 @@
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
-public class UnsafeShuffleWriterSuite {
+public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int NUM_PARTITIONS = 4;
@@ -138,7 +134,7 @@ public void setUp() throws IOException {
Answer<?> renameTempAnswer = invocationOnMock -> {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
- File tmp = (File) invocationOnMock.getArguments()[3];
+ File tmp = (File) invocationOnMock.getArguments()[4];
if (!mergedOutputFile.delete()) {
throw new RuntimeException("Failed to delete old merged output file.");
}
@@ -152,11 +148,13 @@ public void setUp() throws IOException {
doAnswer(renameTempAnswer)
.when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));
+ .writeMetadataFileAndCommit(
+ anyInt(), anyLong(), any(long[].class), any(long[].class), any(File.class));
doAnswer(renameTempAnswer)
.when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));
+ .writeMetadataFileAndCommit(
+ anyInt(), anyLong(), any(long[].class), any(long[].class), eq(null));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -171,7 +169,14 @@ public void setUp() throws IOException {
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
}
- private UnsafeShuffleWriter