diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
index 829884645d9d5..6dc5fd5a70f1a 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
@@ -33,9 +33,9 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors;
-import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors;
+import org.apache.spark.network.shuffle.checksum.Cause;
+import org.apache.spark.network.shuffle.protocol.*;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Provides an interface for reading both shuffle files and RDD blocks, either from an Executor
@@ -46,6 +46,45 @@ public abstract class BlockStoreClient implements Closeable {
 
   protected volatile TransportClientFactory clientFactory;
   protected String appId;
+  protected TransportConf transportConf;
+
+  /**
+   * Send the diagnosis request for the corrupted shuffle block to the server.
+   *
+   * @param host the host of the remote node.
+   * @param port the port of the remote node.
+   * @param execId the executor id.
+   * @param shuffleId the shuffleId of the corrupted shuffle block
+   * @param mapId the mapId of the corrupted shuffle block
+   * @param reduceId the reduceId of the corrupted shuffle block
+   * @param checksum the shuffle checksum that was calculated at the client side for the
+   *                 corrupted shuffle block
+   * @param algorithm the checksum algorithm used to calculate the checksum
+   * @return The cause of the shuffle block corruption
+   */
+  public Cause diagnoseCorruption(
+      String host,
+      int port,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksum,
+      String algorithm) {
+    try {
+      TransportClient client = clientFactory.createClient(host, port);
+      ByteBuffer response = client.sendRpcSync(
+        new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm)
+          .toByteBuffer(),
+        transportConf.connectionTimeoutMs()
+      );
+      CorruptionCause cause =
+        (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response);
+      return cause.cause;
+    } catch (Exception e) {
+      logger.warn("Failed to get the corruption cause.");
+      return Cause.UNKNOWN_ISSUE;
+    }
+  }
 
   /**
    * Fetch a sequence of blocks from a remote node asynchronously,
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index cfabcd5ba4a28..71741f2cba053 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.network.shuffle;
 
-import com.google.common.base.Preconditions;
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -35,8 +34,9 @@
 import com.codahale.metrics.RatioGauge;
 import com.codahale.metrics.Timer;
 import com.codahale.metrics.Counter;
-import com.google.common.collect.Sets;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import
com.google.common.collect.Sets; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +49,7 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.shuffle.checksum.Cause; import org.apache.spark.network.shuffle.protocol.*; import org.apache.spark.network.util.TimerWithCustomTimeUnit; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -223,6 +224,14 @@ protected void handleMessage( } finally { responseDelayContext.stop(); } + } else if (msgObj instanceof DiagnoseCorruption) { + DiagnoseCorruption msg = (DiagnoseCorruption) msgObj; + checkAuth(client, msg.appId); + Cause cause = blockManager.diagnoseShuffleBlockCorruption( + msg.appId, msg.execId, msg.shuffleId, msg.mapId, msg.reduceId, msg.checksum, msg.algorithm); + // In any cases of the error, diagnoseShuffleBlockCorruption should return UNKNOWN_ISSUE, + // so it should always reply as success. + callback.onSuccess(new CorruptionCause(cause).toByteBuffer()); } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index eb2d118b7d4fa..826402c081cce 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -49,7 +49,6 @@ public class ExternalBlockStoreClient extends BlockStoreClient { private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler(); - private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final long registrationTimeoutMs; @@ -63,7 +62,7 @@ public ExternalBlockStoreClient( SecretKeyHolder secretKeyHolder, boolean authEnabled, long registrationTimeoutMs) { - this.conf = conf; + this.transportConf = conf; this.secretKeyHolder = secretKeyHolder; this.authEnabled = authEnabled; this.registrationTimeoutMs = registrationTimeoutMs; @@ -75,10 +74,11 @@ public ExternalBlockStoreClient( */ public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + TransportContext context = new TransportContext( + transportConf, new NoOpRpcHandler(), true, true); List bootstraps = Lists.newArrayList(); if (authEnabled) { - bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + bootstraps.add(new AuthClientBootstrap(transportConf, appId, secretKeyHolder)); } clientFactory = context.createClientFactory(bootstraps); } @@ -94,7 +94,7 @@ public void fetchBlocks( checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { - int maxRetries = conf.maxIORetries(); + int maxRetries = transportConf.maxIORetries(); RetryingBlockTransferor.BlockTransferStarter blockFetchStarter = (inputBlockId, inputListener) -> { // Unless this client is closed. 
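
As a quick sanity check on the new wire protocol introduced above (illustrative only, not part of the patch; the class name, ids, and checksum value are placeholders), the request/response pair can be round-tripped through the same encoder/decoder the RPC layer uses:

    import java.nio.ByteBuffer;

    import org.apache.spark.network.shuffle.checksum.Cause;
    import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
    import org.apache.spark.network.shuffle.protocol.CorruptionCause;
    import org.apache.spark.network.shuffle.protocol.DiagnoseCorruption;

    public class DiagnoseCorruptionRoundTrip {
      public static void main(String[] args) {
        // Request as BlockStoreClient.diagnoseCorruption would send it (placeholder values).
        DiagnoseCorruption req =
          new DiagnoseCorruption("app0", "exec-1", 0, 0L, 0, 12345L, "ADLER32");
        ByteBuffer wire = req.toByteBuffer();
        DiagnoseCorruption decodedReq =
          (DiagnoseCorruption) BlockTransferMessage.Decoder.fromByteBuffer(wire);
        System.out.println(decodedReq.equals(req));  // true

        // Reply as ExternalBlockHandler produces it: always onSuccess, carrying a Cause.
        CorruptionCause reply = new CorruptionCause(Cause.NETWORK_ISSUE);
        CorruptionCause decodedReply =
          (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(reply.toByteBuffer());
        System.out.println(decodedReply.cause);  // NETWORK_ISSUE
      }
    }

Decoding relies on the two new Type ids (DIAGNOSE_CORRUPTION(16), CORRUPTION_CAUSE(17)) registered in BlockTransferMessage further down in this patch.
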
@@ -103,7 +103,7 @@ public void fetchBlocks( "Expecting a BlockFetchingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port, maxRetries > 0); new OneForOneBlockFetcher(client, appId, execId, inputBlockId, - (BlockFetchingListener) inputListener, conf, downloadFileManager).start(); + (BlockFetchingListener) inputListener, transportConf, downloadFileManager).start(); } else { logger.info("This clientFactory was closed. Skipping further block fetch retries."); } @@ -112,7 +112,7 @@ public void fetchBlocks( if (maxRetries > 0) { // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's // a bug in this code. We should remove the if statement once we're sure of the stability. - new RetryingBlockTransferor(conf, blockFetchStarter, blockIds, listener).start(); + new RetryingBlockTransferor(transportConf, blockFetchStarter, blockIds, listener).start(); } else { blockFetchStarter.createAndStart(blockIds, listener); } @@ -146,16 +146,16 @@ public void pushBlocks( assert inputListener instanceof BlockPushingListener : "Expecting a BlockPushingListener, but got " + inputListener.getClass(); TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockPusher(client, appId, conf.appAttemptId(), inputBlockId, + new OneForOneBlockPusher(client, appId, transportConf.appAttemptId(), inputBlockId, (BlockPushingListener) inputListener, buffersWithId).start(); } else { logger.info("This clientFactory was closed. Skipping further block push retries."); } }; - int maxRetries = conf.maxIORetries(); + int maxRetries = transportConf.maxIORetries(); if (maxRetries > 0) { new RetryingBlockTransferor( - conf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start(); + transportConf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start(); } else { blockPushStarter.createAndStart(blockIds, listener); } @@ -178,7 +178,7 @@ public void finalizeShuffleMerge( try { TransportClient client = clientFactory.createClient(host, port); ByteBuffer finalizeShuffleMerge = - new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId, + new FinalizeShuffleMerge(appId, transportConf.appAttemptId(), shuffleId, shuffleMergeId).toByteBuffer(); client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() { @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 493edd2b34628..73d4e6ceb1951 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -45,6 +45,8 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.checksum.Cause; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; @@ -374,6 +376,29 @@ public Map getLocalDirs(String appId, Set execIds) { .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); } + /** + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums + */ + public Cause 
diagnoseShuffleBlockCorruption( + String appId, + String execId, + int shuffleId, + long mapId, + int reduceId, + long checksumByReader, + String algorithm) { + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + // This should be in sync with IndexShuffleBlockResolver.getChecksumFile + String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm; + File checksumFile = ExecutorDiskUtils.getFile( + executor.localDirs, + executor.subDirsPerLocalDir, + fileName); + ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId); + return ShuffleChecksumHelper.diagnoseCorruption( + algorithm, checksumFile, reduceId, data, checksumByReader); + } + /** Simply encodes an executor's full ID, which is appId + execId. */ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java new file mode 100644 index 0000000000000..d316737a16148 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.checksum; + +/** + * The cause of shuffle data corruption. + */ +public enum Cause { + DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS, UNSUPPORTED_CHECKSUM_ALGORITHM +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java new file mode 100644 index 0000000000000..f332f740b3f5f --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.checksum; + +import java.io.*; +import java.util.concurrent.TimeUnit; +import java.util.zip.Adler32; +import java.util.zip.CRC32; +import java.util.zip.CheckedInputStream; +import java.util.zip.Checksum; + +import com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.annotation.Private; +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * A set of utility functions for the shuffle checksum. + */ +@Private +public class ShuffleChecksumHelper { + private static final Logger logger = + LoggerFactory.getLogger(ShuffleChecksumHelper.class); + + public static final int CHECKSUM_CALCULATION_BUFFER = 8192; + public static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; + public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; + + public static Checksum[] createPartitionChecksums(int numPartitions, String algorithm) { + return getChecksumsByAlgorithm(numPartitions, algorithm); + } + + private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) { + Checksum[] checksums; + switch (algorithm) { + case "ADLER32": + checksums = new Adler32[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new Adler32(); + } + return checksums; + + case "CRC32": + checksums = new CRC32[num]; + for (int i = 0; i < num; i++) { + checksums[i] = new CRC32(); + } + return checksums; + + default: + throw new UnsupportedOperationException( + "Unsupported shuffle checksum algorithm: " + algorithm); + } + } + + public static Checksum getChecksumByAlgorithm(String algorithm) { + return getChecksumsByAlgorithm(1, algorithm)[0]; + } + + public static String getChecksumFileName(String blockName, String algorithm) { + // append the shuffle checksum algorithm as the file extension + return String.format("%s.%s", blockName, algorithm); + } + + private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException { + try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) { + ByteStreams.skipFully(in, reduceId * 8); + return in.readLong(); + } + } + + private static long calculateChecksumForPartition( + ManagedBuffer partitionData, + Checksum checksumAlgo) throws IOException { + InputStream in = partitionData.createInputStream(); + byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER]; + try(CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) { + while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {} + return checksumAlgo.getValue(); + } + } + + /** + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums. + * + * There're 3 different kinds of checksums for the same shuffle partition: + * - checksum (c1) that is calculated by the shuffle data reader + * - checksum (c2) that is calculated by the shuffle data writer and stored in the checksum file + * - checksum (c3) that is recalculated during diagnosis + * + * And the diagnosis mechanism works like this: + * If c2 != c3, we suspect the corruption is caused by the DISK_ISSUE. Otherwise, if c1 != c3, + * we suspect the corruption is caused by the NETWORK_ISSUE. Otherwise, the cause remains + * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE. 
+   *
+   * @param algorithm The checksum algorithm that is used for calculating the checksum value
+   *                  of partitionData
+   * @param checksumFile The checksum file that was written by the shuffle writer
+   * @param reduceId The reduceId of the shuffle block
+   * @param partitionData The partition data of the shuffle block
+   * @param checksumByReader The checksum value that was calculated by the shuffle data reader
+   * @return The cause of data corruption
+   */
+  public static Cause diagnoseCorruption(
+      String algorithm,
+      File checksumFile,
+      int reduceId,
+      ManagedBuffer partitionData,
+      long checksumByReader) {
+    Cause cause;
+    try {
+      long diagnoseStartNs = System.nanoTime();
+      // Try to get the checksum instance before reading the checksum file so that
+      // `UnsupportedOperationException` can be thrown first before `FileNotFoundException`
+      // when the checksum algorithm isn't supported.
+      Checksum checksumAlgo = getChecksumByAlgorithm(algorithm);
+      long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId);
+      long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo);
+      long duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - diagnoseStartNs);
+      logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}",
+        duration, checksumFile.getAbsolutePath());
+      if (checksumByWriter != checksumByReCalculation) {
+        cause = Cause.DISK_ISSUE;
+      } else if (checksumByWriter != checksumByReader) {
+        cause = Cause.NETWORK_ISSUE;
+      } else {
+        cause = Cause.CHECKSUM_VERIFY_PASS;
+      }
+    } catch (UnsupportedOperationException e) {
+      cause = Cause.UNSUPPORTED_CHECKSUM_ALGORITHM;
+    } catch (FileNotFoundException e) {
+      // Even if checksum is enabled, a checksum file may not exist if an error occurs during writing.
+      logger.warn("Checksum file " + checksumFile.getName() + " doesn't exist");
+      cause = Cause.UNKNOWN_ISSUE;
+    } catch (Exception e) {
+      logger.warn("Unable to diagnose shuffle block corruption", e);
+      cause = Cause.UNKNOWN_ISSUE;
+    }
+    return cause;
+  }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index a55a6cf7ed939..453791da7bba2 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -49,7 +49,7 @@ public enum Type {
     HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
     FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
     PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14),
-    FETCH_SHUFFLE_BLOCK_CHUNKS(15);
+    FETCH_SHUFFLE_BLOCK_CHUNKS(15), DIAGNOSE_CORRUPTION(16), CORRUPTION_CAUSE(17);
 
     private final byte id;
 
@@ -84,6 +84,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
       case 13: return FinalizeShuffleMerge.decode(buf);
       case 14: return MergeStatuses.decode(buf);
       case 15: return FetchShuffleBlockChunks.decode(buf);
+      case 16: return DiagnoseCorruption.decode(buf);
+      case 17: return CorruptionCause.decode(buf);
       default: throw new IllegalArgumentException("Unknown message type: " + type);
     }
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java
new
file mode 100644 index 0000000000000..5690eee53bd13 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; + +import org.apache.spark.network.shuffle.checksum.Cause; + +/** Response to the {@link DiagnoseCorruption} */ +public class CorruptionCause extends BlockTransferMessage { + public Cause cause; + + public CorruptionCause(Cause cause) { + this.cause = cause; + } + + @Override + protected Type type() { + return Type.CORRUPTION_CAUSE; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("cause", cause) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CorruptionCause that = (CorruptionCause) o; + return cause == that.cause; + } + + @Override + public int hashCode() { + return cause.hashCode(); + } + + @Override + public int encodedLength() { + return 1; /* encoded length of cause */ + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(cause.ordinal()); + } + + public static CorruptionCause decode(ByteBuf buf) { + int ordinal = buf.readByte(); + return new CorruptionCause(Cause.values()[ordinal]); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java new file mode 100644 index 0000000000000..620b5ad71cd75 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */ +public class DiagnoseCorruption extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final long mapId; + public final int reduceId; + public final long checksum; + public final String algorithm; + + public DiagnoseCorruption( + String appId, + String execId, + int shuffleId, + long mapId, + int reduceId, + long checksum, + String algorithm) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.reduceId = reduceId; + this.checksum = checksum; + this.algorithm = algorithm; + } + + @Override + protected Type type() { + return Type.DIAGNOSE_CORRUPTION; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("shuffleId", shuffleId) + .append("mapId", mapId) + .append("reduceId", reduceId) + .append("checksum", checksum) + .append("algorithm", algorithm) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DiagnoseCorruption that = (DiagnoseCorruption) o; + + if (checksum != that.checksum) return false; + if (shuffleId != that.shuffleId) return false; + if (mapId != that.mapId) return false; + if (reduceId != that.reduceId) return false; + if (!algorithm.equals(that.algorithm)) return false; + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; + return true; + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + Integer.hashCode(shuffleId); + result = 31 * result + Long.hashCode(mapId); + result = 31 * result + Integer.hashCode(reduceId); + result = 31 * result + Long.hashCode(checksum); + result = 31 * result + algorithm.hashCode(); + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 4 /* encoded length of shuffleId */ + + 8 /* encoded length of mapId */ + + 4 /* encoded length of reduceId */ + + 8 /* encoded length of checksum */ + + Encoders.Strings.encodedLength(algorithm); /* encoded length of algorithm */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeLong(mapId); + buf.writeInt(reduceId); + buf.writeLong(checksum); + Encoders.Strings.encode(buf, algorithm); + } + + public static DiagnoseCorruption decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + long mapId = buf.readLong(); + int reduceId = buf.readInt(); + long checksum = buf.readLong(); + String algorithm = Encoders.Strings.decode(buf); + return new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 9e0b3c65c9202..d45cbd5adcd98 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -17,14 +17,18 @@ package org.apache.spark.network.shuffle; -import java.io.IOException; +import java.io.*; import java.nio.ByteBuffer; import java.util.Iterator; import java.util.Map; +import java.util.zip.CheckedInputStream; +import java.util.zip.Checksum; import com.codahale.metrics.Meter; import com.codahale.metrics.Metric; import com.codahale.metrics.Timer; +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -42,7 +46,11 @@ import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.checksum.Cause; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.CorruptionCause; +import org.apache.spark.network.shuffle.protocol.DiagnoseCorruption; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; @@ -108,6 +116,111 @@ public void testCompatibilityWithOldVersion() { verifyOpenBlockLatencyMetrics(2, 2); } + private void checkDiagnosisResult( + String algorithm, + Cause expectedCaused) throws IOException { + String appId = "app0"; + String execId = "execId"; + int shuffleId = 0; + long mapId = 0; + int reduceId = 0; + + // prepare the checksum file + File tmpDir = Files.createTempDir(); + File checksumFile = new File(tmpDir, + "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum." 
+ algorithm); + DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile)); + long checksumByReader = 0L; + if (expectedCaused != Cause.UNSUPPORTED_CHECKSUM_ALGORITHM) { + Checksum checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm); + CheckedInputStream checkedIn = new CheckedInputStream( + blockMarkers[0].createInputStream(), checksum); + byte[] buffer = new byte[10]; + ByteStreams.readFully(checkedIn, buffer, 0, (int) blockMarkers[0].size()); + long checksumByWriter = checkedIn.getChecksum().getValue(); + + switch (expectedCaused) { + // when checksumByWriter != checksumRecalculated + case DISK_ISSUE: + out.writeLong(checksumByWriter - 1); + checksumByReader = checksumByWriter; + break; + + // when checksumByWriter == checksumRecalculated and checksumByReader != checksumByWriter + case NETWORK_ISSUE: + out.writeLong(checksumByWriter); + checksumByReader = checksumByWriter - 1; + break; + + case UNKNOWN_ISSUE: + // write a int instead of a long to corrupt the checksum file + out.writeInt(0); + checksumByReader = checksumByWriter; + break; + + default: + out.writeLong(checksumByWriter); + checksumByReader = checksumByWriter; + } + } + out.close(); + + when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId)) + .thenReturn(blockMarkers[0]); + Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(algorithm, checksumFile, reduceId, + blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader); + when(blockResolver + .diagnoseShuffleBlockCorruption( + appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm)) + .thenReturn(actualCause); + + when(client.getClientId()).thenReturn(appId); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + DiagnoseCorruption diagnoseMsg = new DiagnoseCorruption( + appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm); + handler.receive(client, diagnoseMsg.toByteBuffer(), callback); + + ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure(any()); + + CorruptionCause cause = + (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(expectedCaused, cause.cause); + tmpDir.delete(); + } + + @Test + public void testShuffleCorruptionDiagnosisDiskIssue() throws IOException { + checkDiagnosisResult( "ADLER32", Cause.DISK_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisNetworkIssue() throws IOException { + checkDiagnosisResult("ADLER32", Cause.NETWORK_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisUnknownIssue() throws IOException { + checkDiagnosisResult("ADLER32", Cause.UNKNOWN_ISSUE); + } + + @Test + public void testShuffleCorruptionDiagnosisChecksumVerifyPass() throws IOException { + checkDiagnosisResult("ADLER32", Cause.CHECKSUM_VERIFY_PASS); + } + + @Test + public void testShuffleCorruptionDiagnosisUnSupportedAlgorithm() throws IOException { + checkDiagnosisResult("XXX", Cause.UNSUPPORTED_CHECKSUM_ALGORITHM); + } + + @Test + public void testShuffleCorruptionDiagnosisCRC32() throws IOException { + checkDiagnosisResult("CRC32", Cause.CHECKSUM_VERIFY_PASS); + } + @Test public void testFetchShuffleBlocks() { when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]); diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java 
b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java deleted file mode 100644 index a368836d2bb1d..0000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.checksum; - -import java.util.zip.Adler32; -import java.util.zip.CRC32; -import java.util.zip.Checksum; - -import org.apache.spark.SparkConf; -import org.apache.spark.SparkException; -import org.apache.spark.annotation.Private; -import org.apache.spark.internal.config.package$; -import org.apache.spark.storage.ShuffleChecksumBlockId; - -/** - * A set of utility functions for the shuffle checksum. - */ -@Private -public class ShuffleChecksumHelper { - - /** Used when the checksum is disabled for shuffle. */ - private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0]; - public static final long[] EMPTY_CHECKSUM_VALUE = new long[0]; - - public static boolean isShuffleChecksumEnabled(SparkConf conf) { - return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED()); - } - - public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf) - throws SparkException { - if (!isShuffleChecksumEnabled(conf)) { - return EMPTY_CHECKSUM; - } - - String checksumAlgo = shuffleChecksumAlgorithm(conf); - return getChecksumByAlgorithm(numPartitions, checksumAlgo); - } - - private static Checksum[] getChecksumByAlgorithm(int num, String algorithm) - throws SparkException { - Checksum[] checksums; - switch (algorithm) { - case "ADLER32": - checksums = new Adler32[num]; - for (int i = 0; i < num; i ++) { - checksums[i] = new Adler32(); - } - return checksums; - - case "CRC32": - checksums = new CRC32[num]; - for (int i = 0; i < num; i ++) { - checksums[i] = new CRC32(); - } - return checksums; - - default: - throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm); - } - } - - public static long[] getChecksumValues(Checksum[] partitionChecksums) { - int numPartitions = partitionChecksums.length; - long[] checksumValues = new long[numPartitions]; - for (int i = 0; i < numPartitions; i ++) { - checksumValues[i] = partitionChecksums[i].getValue(); - } - return checksumValues; - } - - public static String shuffleChecksumAlgorithm(SparkConf conf) { - return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); - } - - public static Checksum getChecksumByFileExtension(String fileName) throws SparkException { - int index = fileName.lastIndexOf("."); - String algorithm = fileName.substring(index + 1); - return getChecksumByAlgorithm(1, algorithm)[0]; - } - - public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) { - // append the 
shuffle checksum algorithm as the file extension - return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf)); - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java new file mode 100644 index 0000000000000..4f7c3f20b4c7e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.checksum; + +import java.util.zip.Checksum; + +import org.apache.spark.SparkConf; +import org.apache.spark.internal.config.package$; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; + +public interface ShuffleChecksumSupport { + + default Checksum[] createPartitionChecksums(int numPartitions, SparkConf conf) { + if ((boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED())) { + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + return ShuffleChecksumHelper.createPartitionChecksums(numPartitions, checksumAlgorithm); + } else { + return ShuffleChecksumHelper.EMPTY_CHECKSUM; + } + } + + default long[] getChecksumValues(Checksum[] partitionChecksums) { + int numPartitions = partitionChecksums.length; + long[] checksumValues = new long[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + checksumValues[i] = partitionChecksums[i].getValue(); + } + return checksumValues; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 322224053df09..9a5ac6f287beb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -40,10 +40,12 @@ import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkException; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport; import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -51,7 +53,6 @@ import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; -import 
org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -76,7 +77,9 @@ *

* There have been proposals to completely remove this code path; see SPARK-6026 for details. */ -final class BypassMergeSortShuffleWriter extends ShuffleWriter { +final class BypassMergeSortShuffleWriter + extends ShuffleWriter + implements ShuffleChecksumSupport { private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); @@ -125,8 +128,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleExecutorComponents = shuffleExecutorComponents; - this.partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); } @Override @@ -230,9 +232,8 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro } partitionWriters = null; } - return mapOutputWriter.commitAllPartitions( - ShuffleChecksumHelper.getChecksumValues(partitionChecksums) - ).getPartitionLengths(); + return mapOutputWriter.commitAllPartitions(getChecksumValues(partitionChecksums)) + .getPartitionLengths(); } private void writePartitionedDataWithChannel( diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 0307027c6f264..a82f691d085d4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -41,7 +41,7 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.FileSegment; @@ -68,7 +68,7 @@ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. 
*/ -final class ShuffleExternalSorter extends MemoryConsumer { +final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport { private static final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @@ -139,12 +139,11 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.peakMemoryUsedBytes = getMemoryUsage(); this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); - this.partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf); + this.partitionChecksums = createPartitionChecksums(numPartitions, conf); } public long[] getChecksums() { - return ShuffleChecksumHelper.getChecksumValues(partitionChecksums); + return getChecksumValues(partitionChecksums); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 2659b172bf68c..b1779a135b786 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -45,6 +45,7 @@ import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; @@ -57,7 +58,6 @@ import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter; import org.apache.spark.shuffle.api.WritableByteChannelWrapper; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 39c526cb0e8b3..60ba3aac264a5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1372,7 +1372,7 @@ package object config { ConfigBuilder("spark.shuffle.checksum.enabled") .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " + "its best to tell if shuffle data corruption is caused by network or disk or others.") - .version("3.3.0") + .version("3.2.0") .booleanConf .createWithDefault(true) @@ -1380,7 +1380,7 @@ package object config { ConfigBuilder("spark.shuffle.checksum.algorithm") .doc("The algorithm used to calculate the checksum. 
Currently, it only supports" + " built-in algorithms of JDK.") - .version("3.3.0") + .version("3.2.0") .stringConf .transform(_.toUpperCase(Locale.ROOT)) .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index cafb39ea82ad9..89177346a789a 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,11 +22,20 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID +import org.apache.spark.network.shuffle.checksum.Cause import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] trait BlockDataManager { + /** + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums + */ + def diagnoseShuffleBlockCorruption( + blockId: BlockId, + checksumByReader: Long, + algorithm: String): Cause + /** * Get the local directories that used by BlockManager to save the blocks to disk */ diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 5f831dc666ca5..81c878d17c695 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -133,6 +133,13 @@ class NettyBlockRpcServer( Map(actualExecId -> blockManager.getLocalDiskDirs).asJava).toByteBuffer) } } + + case diagnose: DiagnoseCorruption => + val cause = blockManager.diagnoseShuffleBlockCorruption( + ShuffleBlockId(diagnose.shuffleId, diagnose.mapId, diagnose.reduceId ), + diagnose.checksum, + diagnose.algorithm) + responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 4e0beeaec97ad..6da0cb439db1a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -61,7 +61,6 @@ private[spark] class NettyBlockTransferService( // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
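
For reference, a minimal sketch (hypothetical class and method names, not part of the patch) of the checksum-file layout that the server-side diagnosis path relies on: the file holds one 8-byte value per reduce partition and is named in sync with IndexShuffleBlockResolver.getChecksumFile, which is what lets diagnoseShuffleBlockCorruption look up the writer-side checksum for a given reduceId.

    import java.io.DataInputStream;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.IOException;

    import com.google.common.io.ByteStreams;

    public class ChecksumFileLayout {
      // Mirrors the naming used in ExternalShuffleBlockResolver.diagnoseShuffleBlockCorruption:
      // shuffle_<shuffleId>_<mapId>_0.checksum.<ALGORITHM>
      static String checksumFileName(int shuffleId, long mapId, String algorithm) {
        return "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm;
      }

      // Mirrors ShuffleChecksumHelper.readChecksumByReduceId: the file is one 8-byte long per
      // reduce partition, so the entry for partition `reduceId` starts at offset reduceId * 8.
      static long readChecksum(File checksumFile, int reduceId) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
          ByteStreams.skipFully(in, reduceId * 8L);
          return in.readLong();
        }
      }
    }
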
private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ @@ -70,6 +69,7 @@ private[spark] class NettyBlockTransferService( val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None + this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores) if (authEnabled) { serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) @@ -78,6 +78,7 @@ private[spark] class NettyBlockTransferService( clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId + logger.info(s"Server created on $hostName:${server.getPort}") } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 818aa2ef75a9e..df06b07852905 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -83,6 +83,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED), + SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM), readMetrics, fetchContinuousBlocksInBatch).toCompletionIterator diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 07928f8c52252..7454a74094541 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -31,9 +31,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta} +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -157,7 +157,7 @@ private[spark] class IndexShuffleBlockResolver( logWarning(s"Error deleting index ${file.getPath()}") } - file = getChecksumFile(shuffleId, mapId) + file = getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) if (file.exists() && !file.delete()) { logWarning(s"Error deleting checksum ${file.getPath()}") } @@ -339,7 +339,8 @@ private[spark] class IndexShuffleBlockResolver( val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) { assert(lengths.length == checksums.length, "The size of partition lengths and checksums should be equal") - val 
checksumFile = getChecksumFile(shuffleId, mapId) + val checksumFile = + getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) (Some(checksumFile), Some(Utils.tempFileWith(checksumFile))) } else { (None, None) @@ -540,14 +541,13 @@ private[spark] class IndexShuffleBlockResolver( def getChecksumFile( shuffleId: Int, mapId: Long, + algorithm: String, dirs: Option[Array[String]] = None): File = { val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID) - val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf) + val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId.name, algorithm) dirs .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName)) - .getOrElse { - blockManager.diskBlockManager.getFile(fileName) - } + .getOrElse(blockManager.diskBlockManager.getFile(fileName)) } override def getBlockData( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index ce53f08bae8ee..e450129be98f0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -94,7 +94,7 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } -@Since("3.3.0") +@Since("3.2.0") @DeveloperApi case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b81b3b60520c1..4c646b27c270f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,12 +49,13 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.serializer.{SerializerInstance, SerializerManager} -import org.apache.spark.shuffle.{MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter} +import org.apache.spark.shuffle.{IndexShuffleBlockResolver, MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter} import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, ReplicateBlock} import org.apache.spark.storage.memory._ import org.apache.spark.unsafe.Platform @@ -282,6 +283,28 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString + /** + * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums + * + * @param blockId The blockId of the corrupted shuffle block + * @param checksumByReader The checksum value of the corrupted block + * @param algorithm The cheksum algorithm that is used when calculating the checksum value + */ + override def diagnoseShuffleBlockCorruption( + blockId: BlockId, + checksumByReader: 
Long, + algorithm: String): Cause = { + assert(blockId.isInstanceOf[ShuffleBlockId], + s"Corruption diagnosis only supports shuffle block yet, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] + val checksumFile = + resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId, algorithm) + val reduceId = shuffleBlock.reduceId + ShuffleChecksumHelper.diagnoseCorruption( + algorithm, checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader) + } + /** * Abstraction for storing blocks from bytes, whether they start in memory or on disk. * diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index fd87f5e568d0c..3eb8acd4f5560 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,6 +21,7 @@ import java.io.{InputStream, IOException} import java.nio.channels.ClosedByInterruptException import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.CheckedInputStream import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -36,6 +37,7 @@ import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper} import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -69,7 +71,11 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before * throwing the fetch failure. - * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param detectCorrupt whether to detect any corruption in fetched blocks. + * @param checksumEnabled whether the shuffle checksum is enabled. When enabled, Spark will try to + * diagnose the cause of the block corruption. + * @param checksumAlgorithm the checksum algorithm that is used when calculating the checksum value + * for the block data. * @param shuffleMetrics used to report shuffle metrics. * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server * side supports. @@ -89,6 +95,8 @@ final class ShuffleBlockFetcherIterator( maxAttemptsOnNettyOOM: Int, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, + checksumEnabled: Boolean, + checksumAlgorithm: String, shuffleMetrics: ShuffleReadMetricsReporter, doBatchFetch: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { @@ -732,6 +740,8 @@ final class ShuffleBlockFetcherIterator( var result: FetchResult = null var input: InputStream = null + // This's only initialized when shuffle checksum is enabled. 
+ var checkedIn: CheckedInputStream = null var streamCompressedOrEncrypted: Boolean = false // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch @@ -787,7 +797,14 @@ final class ShuffleBlockFetcherIterator( } val in = try { - buf.createInputStream() + var bufIn = buf.createInputStream() + if (checksumEnabled) { + val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm) + checkedIn = new CheckedInputStream(bufIn, checksum) + checkedIn + } else { + bufIn + } } catch { // The exception could only be throwed by local shuffle block case e: IOException => @@ -822,8 +839,15 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: IOException => - buf.release() + // When shuffle checksum is enabled, for a block that is corrupted twice, + // we'd calculate the checksum of the block by consuming the remaining data + // in the buf. So, we should release the buf later. + if (!(checksumEnabled && corruptedBlocks.contains(blockId))) { + buf.release() + } + if (blockId.isShuffleChunk) { + // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle // Retrying a corrupt block may result again in a corrupt block. For shuffle // chunks, we opt to fallback on the original shuffle blocks that belong to that // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt @@ -834,17 +858,27 @@ final class ShuffleBlockFetcherIterator( pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop. result = null - } else { - if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, mapIndex, address, e) + } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(blockId, mapIndex, address, e) + } else if (corruptedBlocks.contains(blockId)) { + // It's the second time this block is detected corrupted + if (checksumEnabled) { + // Diagnose the cause of data corruption if shuffle checksum is enabled + val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId) + buf.release() + logError(diagnosisResponse) + throwFetchFailedException( + blockId, mapIndex, address, e, Some(diagnosisResponse)) } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest( - address, Array(FetchBlockInfo(blockId, size, mapIndex))) - result = null + throwFetchFailedException(blockId, mapIndex, address, e) } + } else { + // It's the first time this block is detected corrupted + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) + result = null } } finally { if (blockId.isShuffleChunk) { @@ -975,7 +1009,66 @@ final class ShuffleBlockFetcherIterator( currentResult.mapIndex, currentResult.address, detectCorrupt && streamCompressedOrEncrypted, - currentResult.isNetworkReqDone)) + currentResult.isNetworkReqDone, + Option(checkedIn))) + } + + /** + * Get the suspected corruption cause for the corrupted block. It should only be invoked + * when checksum is enabled and corruption was detected at least once. + * + * This will first consume the rest of the stream of the corrupted block to calculate the + checksum of the block. 
Then, it will issue a synchronous RPC call along with the + checksum to ask the server (where the corrupted block is fetched from) to diagnose the + cause of corruption and return it. + * + * Any exception raised during the process will result in [[Cause.UNKNOWN_ISSUE]] as the + corruption cause since corruption diagnosis is only a best effort. + * + * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum. + * @param address the address where the corrupted block is fetched from. + * @param blockId the blockId of the corrupted block. + * @return The corruption diagnosis response for different causes. + */ + private[storage] def diagnoseCorruption( + checkedIn: CheckedInputStream, + address: BlockManagerId, + blockId: BlockId): String = { + logInfo("Start corruption diagnosis.") + val startTimeNs = System.nanoTime() + assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" + + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" + + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" + + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -1158,7 +1251,8 @@ private class BufferReleasingInputStream( private val mapIndex: Int, private val address: BlockManagerId, private val detectCorruption: Boolean, - private val isNetworkReqDone: Boolean) + private val isNetworkReqDone: Boolean, + private val checkedInOpt: Option[CheckedInputStream]) extends InputStream { private[this] var closed = false @@ -1207,8 +1301,13 @@ private class BufferReleasingInputStream( block } catch { case e: IOException if detectCorruption => + val diagnosisResponse = checkedInOpt.map { checkedIn => + iterator.diagnoseCorruption(checkedIn, address, blockId) + } IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, mapIndex, address, e) + // We'd never retry the block whatever the cause is since the block has been + // partially consumed by downstream RDDs. 
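// Minimal JDK-only sketch of the two reader-side steps described in the scaladoc above: the
// fetch stream is wrapped in a CheckedInputStream up front, so once corruption is detected the
// iterator only has to drain the remainder before reading the accumulated checksum value and
// sending it to the server via the diagnosis RPC. The 8192 buffer size here is an arbitrary
// choice for the sketch, not the value of CHECKSUM_CALCULATION_BUFFER.
import java.io.ByteArrayInputStream
import java.util.zip.{Adler32, CheckedInputStream}

def drainAndGetChecksum(checkedIn: CheckedInputStream): Long = {
  val buffer = new Array[Byte](8192)
  while (checkedIn.read(buffer) != -1) {}   // consume what is left of the block
  checkedIn.getChecksum.getValue            // covers every byte that passed through the stream
}

// e.g. the checksum of a whole (here: tiny, in-memory) block
val demo = new CheckedInputStream(new ByteArrayInputStream(Array[Byte](1, 2, 3)), new Adler32())
drainAndGetChecksum(demo)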
+ iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse) } } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c63e196ddc814..eda408afa7ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ import org.apache.spark.shuffle.ShufflePartitionPairsWriter import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper +import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} import org.apache.spark.util.{CompletionIterator, Utils => TryUtils} @@ -97,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C]( ordering: Option[Ordering[K]] = None, serializer: Serializer = SparkEnv.get.serializer) extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager()) - with Logging { + with Logging with ShuffleChecksumSupport { private val conf = SparkEnv.get.conf @@ -142,10 +142,9 @@ private[spark] class ExternalSorter[K, V, C]( private val forceSpillFiles = new ArrayBuffer[SpilledFile] @volatile private var readingIterator: SpillableIterator = null - private val partitionChecksums = - ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf) + private val partitionChecksums = createPartitionChecksums(numPartitions, conf) - def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums) + def getChecksums: Array[Long] = getChecksumValues(partitionChecksums) // A comparator for keys K that orders them within a partition to allow aggregation or sorting. 
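// Hedged sketch of the write-side bookkeeping that ShuffleChecksumSupport gives ExternalSorter
// in the hunk above: one Checksum per reduce partition, updated as that partition's bytes are
// written, then exposed as an Array[Long]. The names here are illustrative, not the trait's
// actual members.
import java.util.zip.{Adler32, Checksum}

class PartitionChecksums(numPartitions: Int) {
  private val checksums: Array[Checksum] = Array.fill(numPartitions)(new Adler32())

  def update(partition: Int, bytes: Array[Byte]): Unit =
    checksums(partition).update(bytes, 0, bytes.length)

  // One long per partition; persisted alongside the data and index files as the ".checksum" file.
  def values: Array[Long] = checksums.map(_.getValue)
}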
// Can be a partial ordering by hash code if a total ordering is not provided through by the diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 63220ed49f56c..87f9ab32eb585 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,8 +23,8 @@ import java.util.*; import org.apache.spark.*; +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper; import org.apache.spark.shuffle.ShuffleChecksumTestHelper; -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper; import org.mockito.stubbing.Answer; import scala.*; import scala.collection.Iterator; @@ -38,11 +38,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZ4CompressionCodec; import org.apache.spark.io.LZFCompressionCodec; import org.apache.spark.io.SnappyCompressionCodec; -import org.apache.spark.internal.config.package$; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; @@ -301,12 +301,13 @@ public void writeChecksumFileWithoutSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); - File checksumFile = new File(tempDir, - ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), checksumAlgorithm); + File checksumFile = new File(tempDir, checksumFileName); File dataFile = new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(checksumFile.getName())) - .thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) @@ -322,7 +323,7 @@ public void writeChecksumFileWithoutSpill() throws Exception { writer1.stop(true); assertTrue(checksumFile.exists()); assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); - compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile); } @Test @@ -330,11 +331,13 @@ public void writeChecksumFileWithSpill() throws Exception { IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager); ShuffleChecksumBlockId checksumBlockId = new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID()); - File checksumFile = - new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)); + String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM()); + String checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name(), checksumAlgorithm); + File checksumFile = new File(tempDir, checksumFileName); File dataFile 
= new File(tempDir, "data"); File indexFile = new File(tempDir, "index"); - when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile); + when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile); when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0))) .thenReturn(dataFile); when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0))) @@ -356,7 +359,7 @@ public void writeChecksumFileWithSpill() throws Exception { writer1.closeAndWriteOutput(); assertTrue(checksumFile.exists()); assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS); - compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile); + compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile); } private void testMergingSpills( diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index fd75d91d8dd2b..c1a964c336109 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.io.File +import java.io.{File, RandomAccessFile} import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService } @@ -447,6 +447,41 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi } } } + + test("SPARK-36206: shuffle checksum detect disk corruption") { + val newConf = conf.clone + .set(config.SHUFFLE_CHECKSUM_ENABLED, true) + .set(TEST_NO_STAGE_RETRY, false) + .set("spark.stage.maxConsecutiveAttempts", "1") + sc = new SparkContext("local-cluster[2, 1, 2048]", "test", newConf) + val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _) + // materialize the shuffle map outputs + rdd.count() + + sc.parallelize(1 to 10, 2).barrier().mapPartitions { iter => + var dataFile = SparkEnv.get.blockManager + .diskBlockManager.getFile(ShuffleDataBlockId(0, 0, 0)) + if (!dataFile.exists()) { + dataFile = SparkEnv.get.blockManager + .diskBlockManager.getFile(ShuffleDataBlockId(0, 1, 0)) + } + + if (dataFile.exists()) { + val f = new RandomAccessFile(dataFile, "rw") + // corrupt the shuffle data files by writing some arbitrary bytes + f.seek(0) + f.write(Array[Byte](12)) + f.close() + } + BarrierTaskContext.get().barrier() + iter + }.collect() + + val e = intercept[SparkException] { + rdd.count() + } + assert(e.getMessage.contains("corrupted due to DISK_ISSUE")) + } } /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala index a8f2c4088c422..3db2f77fe1534 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala @@ -20,15 +20,20 @@ package org.apache.spark.shuffle import java.io.{DataInputStream, File, FileInputStream} import java.util.zip.CheckedInputStream +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper trait ShuffleChecksumTestHelper { /** * Ensure that the checksum values are consistent between write and read side. 
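// Simplified, self-contained version of the check the test helper below performs (the real
// helper uses LimitedInputStream and the configured algorithm; Adler32 is hardcoded here): read
// one stored long per partition from the checksum file, and compare it with a checksum
// recomputed over that partition's byte range as given by the offsets in the index file. This
// is also why the tests above assert a checksum file length of 8 * NUM_PARTITIONS bytes.
import java.io.{DataInputStream, File, FileInputStream}
import java.nio.file.Files
import java.util.zip.Adler32

def checksumsMatch(numPartitions: Int, checksum: File, data: File, index: File): Boolean = {
  val checksumIn = new DataInputStream(new FileInputStream(checksum))
  val indexIn = new DataInputStream(new FileInputStream(index))
  val bytes = Files.readAllBytes(data.toPath)
  try {
    var prevOffset = indexIn.readLong()   // the index file starts with offset 0
    (0 until numPartitions).forall { _ =>
      val curOffset = indexIn.readLong()
      val adler = new Adler32()
      adler.update(bytes, prevOffset.toInt, (curOffset - prevOffset).toInt)
      prevOffset = curOffset
      checksumIn.readLong() == adler.getValue
    }
  } finally {
    checksumIn.close()
    indexIn.close()
  }
}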
*/ - def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = { + def compareChecksums( + numPartition: Int, + algorithm: String, + checksum: File, + data: File, + index: File): Unit = { assert(checksum.exists(), "Checksum file doesn't exist") assert(data.exists(), "Data file doesn't exist") assert(index.exists(), "Index file doesn't exist") @@ -55,7 +60,7 @@ trait ShuffleChecksumTestHelper { val curOffset = indexIn.readLong val limit = (curOffset - prevOffset).toInt val bytes = new Array[Byte](limit) - val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName) + val checksumCal = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm) checkedIn = new CheckedInputStream( new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal) checkedIn.read(bytes, 0, limit) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 39eef9749eac3..38ed702d0e4c7 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -31,11 +31,12 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.internal.config import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper} import org.apache.spark.shuffle.api.ShuffleExecutorComponents -import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -248,12 +249,14 @@ class BypassMergeSortShuffleWriterSuite val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0) val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0) val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0) - val checksumFile = new File(tempDir, - ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf)) + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val checksumFileName = ShuffleChecksumHelper.getChecksumFileName( + checksumBlockId.name, checksumAlgorithm) + val checksumFile = new File(tempDir, checksumFileName) val dataFile = new File(tempDir, dataBlockId.name) val indexFile = new File(tempDir, indexBlockId.name) reset(diskBlockManager) - when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile) + when(diskBlockManager.getFile(checksumFileName)).thenAnswer(_ => checksumFile) when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile) when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile) when(diskBlockManager.createTempShuffleBlock()) @@ -277,6 +280,6 @@ class BypassMergeSortShuffleWriterSuite writer.stop( /* success = */ true) assert(checksumFile.exists()) assert(checksumFile.length() === 8 * numPartition) - compareChecksums(numPartition, checksumFile, dataFile, indexFile) + compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) } } diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index abe2b5694bef5..21704b1c67325 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -262,7 +262,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9) val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9) resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp) - val checksumFile = resolver.getChecksumFile(0, 0) + val checksumFile = resolver.getChecksumFile(0, 0, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)) assert(checksumFile.exists()) val checksumFileName = checksumFile.toString val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index e3457367d9baf..6c13c7c8c3c61 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.PrivateMethodTester import org.scalatest.matchers.must.Matchers import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper} @@ -165,12 +166,13 @@ class SortShuffleWriterSuite val expectSpillSize = if (doSpill) records.size else 0 assert(sorter.numSpills === expectSpillSize) writer.stop(success = true) - val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0) + val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM) + val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0, checksumAlgorithm) assert(checksumFile.exists()) assert(checksumFile.length() === 8 * numPartition) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0) val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0) - compareChecksums(numPartition, checksumFile, dataFile, indexFile) + compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile) localSC.stop() } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c22e1d0ca2244..8ed009882bdd1 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -21,11 +21,13 @@ import java.io._ import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.{CompletableFuture, Semaphore} +import java.util.zip.CheckedInputStream import scala.collection.mutable import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future +import com.google.common.io.ByteStreams import io.netty.util.internal.OutOfDirectMemoryError import org.apache.log4j.Level import 
org.mockito.ArgumentMatchers.{any, eq => meq} @@ -157,14 +159,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(buffer, times(0)).release() val delegateAccess = PrivateMethod[InputStream](Symbol("delegate")) - - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + var in = wrappedInputStream.invokePrivate(delegateAccess()) + if (in.isInstanceOf[CheckedInputStream]) { + val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in") + underlyingInputFiled.setAccessible(true) + in = underlyingInputFiled.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream] + } + verify(in, times(0)).close() wrappedInputStream.close() verify(buffer, times(1)).release() - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + verify(in, times(1)).close() wrappedInputStream.close() // close should be idempotent verify(buffer, times(1)).release() - verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + verify(in, times(1)).close() } // scalastyle:off argcount @@ -180,6 +187,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxAttemptsOnNettyOOM: Int = 10, detectCorrupt: Boolean = true, detectCorruptUseExtraMemory: Boolean = true, + checksumEnabled: Boolean = true, + checksumAlgorithm: String = "ADLER32", shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { val tContext = taskContext.getOrElse(TaskContext.empty()) @@ -197,6 +206,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxAttemptsOnNettyOOM, detectCorrupt, detectCorruptUseExtraMemory, + checksumEnabled, + checksumAlgorithm, shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), doBatchFetch) } @@ -213,6 +224,69 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq } + test("SPARK-36206: diagnose the block when it's corrupted twice") { + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer() + ) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess(ShuffleBlockId(0, 0, 0).toString, mockCorruptBuffer()) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100) + ) + intercept[FetchFailedException](iterator.next()) + // The block will be fetched twice due to retry + verify(transfer, times(2)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + // only diagnose once + assert(logAppender.loggingEvents.count( + _.getRenderedMessage.contains("Start corruption diagnosis")) === 1) + } + } + + test("SPARK-36206: diagnose the block when it's corrupted " + + "inside BufferReleasingInputStream") { + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer() + ) + answerFetchBlocks { invocation => + val 
listener = invocation.getArgument[BlockFetchingListener](4) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, + mockCorruptBuffer(100, 50)) + } + + val logAppender = new LogAppender("diagnose corruption") + withLogAppender(logAppender) { + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100), + maxBytesInFlight = 100 + ) + intercept[FetchFailedException] { + val inputStream = iterator.next()._2 + // Consume the data to trigger the corruption + ByteStreams.readFully(inputStream, new Array[Byte](100)) + } + // The block will be fetched only once because corruption can't be detected in + // maxBytesInFlight/3 of the data size + verify(transfer, times(1)) + .fetchBlocks(any(), any(), any(), any(), any(), any()) + // only diagnose once + assert(logAppender.loggingEvents.exists( + _.getRenderedMessage.contains("Start corruption diagnosis"))) + } + } + test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId
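// Side note on the reflection used by the delegate-close test earlier in this suite: a
// CheckedInputStream does not expose the stream it wraps, so the test reads FilterInputStream's
// protected "in" field reflectively to reach the mock underneath. A minimal standalone sketch
// of that unwrapping (setAccessible may emit a warning on newer JDKs):
import java.io.{ByteArrayInputStream, FilterInputStream, InputStream}
import java.util.zip.{Adler32, CheckedInputStream}

def unwrap(stream: InputStream): InputStream = stream match {
  case checked: CheckedInputStream =>
    // CheckedInputStream extends FilterInputStream, which declares the wrapped stream as "in"
    val inField = classOf[FilterInputStream].getDeclaredField("in")
    inField.setAccessible(true)
    inField.get(checked).asInstanceOf[InputStream]
  case other => other
}

val inner = new ByteArrayInputStream(Array[Byte](1, 2, 3))
assert(unwrap(new CheckedInputStream(inner, new Adler32())) eq inner)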