[SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file

### What changes were proposed in this pull request?

This is the initial work to add checksum support for shuffle. It is a piece of apache#32385, and this PR only adds checksum functionality on the shuffle writer side.

Basically, the idea is to wrap a `MutableCheckedOutputStream`* around the `FileOutputStream` while the shuffle writer generates the shuffle data. However, the specific wrapping place differs among the shuffle writers due to their different implementations:

* `BypassMergeSortShuffleWriter` - wrap on each partition file
* `UnsafeShuffleWriter` - wrap on each spill file directly, since spills don't require aggregation or sorting
* `SortShuffleWriter` - wrap on the `ShufflePartitionPairsWriter` after the spill files are merged, since the records might require aggregation or sorting

\* `MutableCheckedOutputStream` is a variant of `java.util.zip.CheckedOutputStream` whose checksum calculator can be changed at runtime.

We use `Adler32` to calculate the checksum, the same as `Broadcast`'s checksum; it is much faster to compute than CRC-32, which `ShuffleChecksumHelper` also supports. A minimal sketch of the stream-wrapping idea is shown below.
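
For illustration only, here is a minimal sketch of such a checked output stream with a swappable checksum calculator. The class shape is an assumption based on the footnote above; Spark's actual `MutableCheckedOutputStream` may differ.

```java
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.zip.Checksum;

// Like java.util.zip.CheckedOutputStream, but the Checksum can be swapped at runtime,
// e.g. when a writer moves on to the next reduce partition.
class MutableCheckedOutputStream extends FilterOutputStream {
  private Checksum checksum;

  MutableCheckedOutputStream(OutputStream out) {
    super(out);
  }

  void setChecksum(Checksum checksum) {
    this.checksum = checksum;
  }

  @Override
  public void write(int b) throws IOException {
    if (checksum != null) {
      checksum.update(b);
    }
    out.write(b);
  }

  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    if (checksum != null) {
      checksum.update(b, off, len);
    }
    out.write(b, off, len);
  }
}
```

In all three writers the pattern is the same: the bytes of a partition flow through that partition's `Checksum` before reaching the underlying file stream.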

### Why are the changes needed?

Shuffle data corruption is hard to diagnose. Recording a checksum for each partition at shuffle-write time gives Spark, and shuffle extensions, something to compare the checksum of a fetched partition against when corruption happens, which helps diagnose the cause of the corruption (see the javadoc added to `ShuffleMapOutputWriter` below).

### Does this PR introduce _any_ user-facing change?

Yes, added a new conf: `spark.shuffle.checksum`.
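
As a usage sketch (not part of this commit), enabling the feature and choosing the algorithm could look like the code below. The exact conf keys are an assumption: this commit's description calls the flag `spark.shuffle.checksum`, while the `SHUFFLE_CHECKSUM_ENABLED` / `SHUFFLE_CHECKSUM_ALGORITHM` constants referenced in the diff correspond to `spark.shuffle.checksum.enabled` / `spark.shuffle.checksum.algorithm` in current Spark releases, so treat the keys as version-dependent.

```java
import org.apache.spark.SparkConf;

public class EnableShuffleChecksum {
  public static void main(String[] args) {
    // Keys assumed as described above; check the config constants of your Spark version.
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.checksum.enabled", "true")
        .set("spark.shuffle.checksum.algorithm", "ADLER32"); // or "CRC32"
    System.out.println(conf.get("spark.shuffle.checksum.algorithm"));
  }
}
```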

### How was this patch tested?

Added unit tests.

Closes apache#32401 from Ngone51/add-checksum-files.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
Ngone51 authored and Mridul Muralidharan committed Jul 17, 2021
1 parent 37dc3f9 commit 4783fb7
Showing 25 changed files with 785 additions and 106 deletions.
@@ -59,6 +59,10 @@ public interface ShuffleMapOutputWriter {
* available to downstream reduce tasks. If this method throws any exception, this module's
* {@link #abort(Throwable)} method will be invoked before propagating the exception.
* <p>
* Shuffle extensions which care about the cause of shuffle data corruption should store
* the checksums properly. When corruption happens, Spark would provide the checksum
* of the fetched partition to the shuffle extension to help diagnose the cause of corruption.
* <p>
* This can also close any resources and clean up temporary state if necessary.
* <p>
* The returned commit message is a structure with two components:
@@ -68,8 +72,11 @@ public interface ShuffleMapOutputWriter {
* for that partition id.
* <p>
* 2) An optional metadata blob that can be used by shuffle readers.
*
* @param checksums The checksum values for each partition (where checksum index is equivalent to
* partition id) if shuffle checksum enabled. Otherwise, it's empty.
*/
MapOutputCommitMessage commitAllPartitions() throws IOException;
MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException;

/**
* Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
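
As an illustration of the new contract above (everything in this sketch is hypothetical, not code from this commit): a plugin receiving the `checksums` array can treat index `i` as the checksum of reduce partition `i`, and an empty array as "checksums disabled".

```java
// Hypothetical sketch of what a shuffle extension might do with the arguments
// available when commitAllPartitions(long[] checksums) is called.
final class ChecksumAwareCommitSketch {
  static void onCommit(long[] partitionLengths, long[] checksums) {
    if (checksums.length == 0) {
      System.out.println("shuffle checksums disabled");
      return;
    }
    // When enabled, there is one checksum per reduce partition, indexed by partition id.
    for (int i = 0; i < checksums.length; i++) {
      System.out.printf("partition %d: %d bytes, checksum %d%n",
          i, partitionLengths[i], checksums[i]);
    }
  }

  public static void main(String[] args) {
    onCommit(new long[]{10L, 0L, 42L}, new long[]{123L, 1L, 456L});
  }
}
```
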
@@ -49,7 +49,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link OutputStream#close()} does not close the resource, since it will be reused across
* partition writes. The underlying resources should be cleaned up in
* {@link ShuffleMapOutputWriter#commitAllPartitions()} and
* {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
OutputStream openStream() throws IOException;
@@ -68,7 +68,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
* will be reused across partition writes. The underlying resources should be cleaned up in
* {@link ShuffleMapOutputWriter#commitAllPartitions()} and
* {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
* <p>
* This method is primarily for advanced optimizations where bytes can be copied from the input
@@ -79,7 +79,7 @@ public interface ShufflePartitionWriter {
* Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
* underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
* that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
* {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
* {@link ShuffleMapOutputWriter#commitAllPartitions(long[])}, or
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
@@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter {
/**
* Transfer a file that contains the bytes of all the partitions written by this map task.
*/
void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
void transferMapSpillFile(
File mapOutputFile,
long[] partitionLengths,
long[] checksums) throws IOException;
}
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.checksum;

import java.util.zip.Adler32;
import java.util.zip.CRC32;
import java.util.zip.Checksum;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.annotation.Private;
import org.apache.spark.internal.config.package$;
import org.apache.spark.storage.ShuffleChecksumBlockId;

/**
* A set of utility functions for the shuffle checksum.
*/
@Private
public class ShuffleChecksumHelper {

/** Used when the checksum is disabled for shuffle. */
private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];

public static boolean isShuffleChecksumEnabled(SparkConf conf) {
return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED());
}

public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf)
throws SparkException {
if (!isShuffleChecksumEnabled(conf)) {
return EMPTY_CHECKSUM;
}

String checksumAlgo = shuffleChecksumAlgorithm(conf);
return getChecksumByAlgorithm(numPartitions, checksumAlgo);
}

private static Checksum[] getChecksumByAlgorithm(int num, String algorithm)
throws SparkException {
Checksum[] checksums;
switch (algorithm) {
case "ADLER32":
checksums = new Adler32[num];
for (int i = 0; i < num; i ++) {
checksums[i] = new Adler32();
}
return checksums;

case "CRC32":
checksums = new CRC32[num];
for (int i = 0; i < num; i ++) {
checksums[i] = new CRC32();
}
return checksums;

default:
throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm);
}
}

public static long[] getChecksumValues(Checksum[] partitionChecksums) {
int numPartitions = partitionChecksums.length;
long[] checksumValues = new long[numPartitions];
for (int i = 0; i < numPartitions; i ++) {
checksumValues[i] = partitionChecksums[i].getValue();
}
return checksumValues;
}

public static String shuffleChecksumAlgorithm(SparkConf conf) {
return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
}

public static Checksum getChecksumByFileExtension(String fileName) throws SparkException {
int index = fileName.lastIndexOf(".");
String algorithm = fileName.substring(index + 1);
return getChecksumByAlgorithm(1, algorithm)[0];
}

public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) {
// append the shuffle checksum algorithm as the file extension
return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf));
}
}
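
A brief usage sketch of the helper above (standalone; whether any checksums are created depends on the shuffle-checksum conf in the default `SparkConf`):

```java
import java.util.zip.Checksum;

import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;

public class ShuffleChecksumHelperExample {
  public static void main(String[] args) throws Exception {
    SparkConf conf = new SparkConf();
    // One Checksum per reduce partition; empty if the feature is disabled.
    Checksum[] checksums = ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(4, conf);
    byte[] fakePartitionBytes = {1, 2, 3};
    if (checksums.length > 0) {
      // A writer would feed each partition's bytes into its calculator like this...
      checksums[0].update(fakePartitionBytes, 0, fakePartitionBytes.length);
    }
    // ...and pass the final values to commitAllPartitions.
    long[] values = ShuffleChecksumHelper.getChecksumValues(checksums);
    System.out.println(java.util.Arrays.toString(values));
  }
}
```
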
@@ -23,6 +23,7 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Optional;
import java.util.zip.Checksum;
import javax.annotation.Nullable;

import scala.None$;
@@ -38,6 +39,7 @@
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -49,6 +51,7 @@
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;

@@ -93,6 +96,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
/** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
private final Checksum[] partitionChecksums;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -107,7 +112,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
long mapId,
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics,
ShuffleExecutorComponents shuffleExecutorComponents) {
ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -120,6 +125,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
this.partitionChecksums =
ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
}

@Override
@@ -129,7 +136,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
partitionLengths = mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
return;
@@ -143,8 +151,12 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
DiskBlockObjectWriter writer =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
if (partitionChecksums.length > 0) {
writer.setChecksum(partitionChecksums[i]);
}
partitionWriters[i] = writer;
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -218,7 +230,9 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
}
partitionWriters = null;
}
return mapOutputWriter.commitAllPartitions().getPartitionLengths();
return mapOutputWriter.commitAllPartitions(
ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
).getPartitionLengths();
}

private void writePartitionedDataWithChannel(
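
The `DiskBlockObjectWriter.setChecksum` call introduced above is not shown in this excerpt; conceptually it makes the writer route every byte it writes through the partition's `Checksum`, which the JDK's `CheckedOutputStream` already models. A standalone illustration (file name and data are made up):

```java
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.zip.Adler32;
import java.util.zip.CheckedOutputStream;

public class ChecksumOnWriteExample {
  public static void main(String[] args) throws IOException {
    Adler32 partitionChecksum = new Adler32();
    // Every byte written to the stream also updates partitionChecksum.
    try (CheckedOutputStream out = new CheckedOutputStream(
        new FileOutputStream("partition-0.data"), partitionChecksum)) {
      out.write("some shuffle records".getBytes(StandardCharsets.UTF_8));
    }
    // This is the kind of value that ends up in the long[] passed to commitAllPartitions.
    System.out.println("checksum = " + partitionChecksum.getValue());
  }
}
```
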
@@ -21,7 +21,9 @@
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
import java.util.zip.Checksum;

import org.apache.spark.SparkException;
import scala.Tuple2;

import com.google.common.annotations.VisibleForTesting;
@@ -39,6 +41,7 @@
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -107,14 +110,17 @@ final class ShuffleExternalSorter extends MemoryConsumer {
@Nullable private MemoryBlock currentPage = null;
private long pageCursor = -1;

// Checksum calculator for each partition. Empty when shuffle checksum disabled.
private final Checksum[] partitionChecksums;

ShuffleExternalSorter(
TaskMemoryManager memoryManager,
BlockManager blockManager,
TaskContext taskContext,
int initialSize,
int numPartitions,
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics) {
ShuffleWriteMetricsReporter writeMetrics) throws SparkException {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -133,6 +139,12 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.peakMemoryUsedBytes = getMemoryUsage();
this.diskWriteBufferSize =
(int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
this.partitionChecksums =
ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
}

public long[] getChecksums() {
return ShuffleChecksumHelper.getChecksumValues(partitionChecksums);
}

/**
@@ -204,6 +216,9 @@ private void writeSortedFile(boolean isLastFile) {
spillInfo.partitionLengths[currentPartition] = fileSegment.length();
}
currentPartition = partition;
if (partitionChecksums.length > 0) {
writer.setChecksum(partitionChecksums[currentPartition]);
}
}

final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
@@ -57,6 +57,7 @@
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
@@ -115,7 +116,7 @@ public UnsafeShuffleWriter(
TaskContext taskContext,
SparkConf sparkConf,
ShuffleWriteMetricsReporter writeMetrics,
ShuffleExecutorComponents shuffleExecutorComponents) {
ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -198,7 +199,7 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
}
}

private void open() {
private void open() throws SparkException {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
@@ -219,10 +220,10 @@ void closeAndWriteOutput() throws IOException {
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
try {
partitionLengths = mergeSpills(spills);
} finally {
sorter = null;
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
@@ -267,7 +268,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
if (spills.length == 0) {
final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
return mapWriter.commitAllPartitions().getPartitionLengths();
return mapWriter.commitAllPartitions(
ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
} else if (spills.length == 1) {
Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
@@ -277,7 +279,8 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
partitionLengths = spills[0].partitionLengths;
logger.debug("Merge shuffle spills for mapId {} with length {}", mapId,
partitionLengths.length);
maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
maybeSingleFileWriter.get()
.transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums());
} else {
partitionLengths = mergeSpillsUsingStandardWriter(spills);
}
@@ -330,7 +333,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -98,7 +98,7 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I
}

@Override
public MapOutputCommitMessage commitAllPartitions() throws IOException {
public MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException {
// Check the position after transferTo loop to see if it is in the right position and raise a
// exception if it is incorrect. The position will not be increased to the expected length
// after calling transferTo in kernel version 2.6.32. This issue is described at
@@ -115,7 +115,8 @@ public MapOutputCommitMessage commitAllPartitions() throws IOException {
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
log.debug("Writing shuffle index file for mapId {} with length {}", mapId,
partitionLengths.length);
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
blockResolver
.writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, resolvedTmp);
return MapOutputCommitMessage.of(partitionLengths);
}

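`writeMetadataFileAndCommit` (which replaces `writeIndexFileAndCommit` and also persists the checksum file) is not part of this excerpt. For illustration only, persisting one long per partition to a file named via `ShuffleChecksumHelper.getChecksumFileName` could look roughly like the sketch below; the file name and on-disk format here are assumptions, not necessarily what Spark writes.

```java
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;

public class ChecksumFileSketch {
  // Hypothetical format: one 8-byte checksum value per reduce partition, in partition order.
  static void writeChecksumFile(File file, long[] checksums) throws IOException {
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
      for (long checksum : checksums) {
        out.writeLong(checksum);
      }
    }
  }

  public static void main(String[] args) throws IOException {
    // ".ADLER32" mimics getChecksumFileName's "<blockId>.<algorithm>" pattern.
    writeChecksumFile(new File("example.checksum.ADLER32"), new long[]{123L, 456L});
  }
}
```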