diff --git a/core/pom.xml b/core/pom.xml
index fc42f48973fe9..e36d1a45aa9a2 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -361,6 +361,16 @@
junittest
+
+ org.hamcrest
+ hamcrest-core
+ test
+
+
+ org.hamcrest
+ hamcrest-library
+ test
+ com.novocodejunit-interface
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java
new file mode 100644
index 0000000000000..8b5ba49e67204
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/TimeTrackingOutputStream.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.executor.ShuffleWriteMetrics;
+
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Output stream decorator that forwards every call to a {@link FileOutputStream} and adds the
+ * elapsed {@link System#nanoTime()} of each {@code write} call to the shuffle write time of
+ * the supplied {@link ShuffleWriteMetrics}. {@code flush()} and {@code close()} are not timed.
+ *
+ * Note: the class name differs from the enclosing file name (TimeTrackingOutputStream.java);
+ * this compiles only because the class is package-private.
+ */
+final class TimeTrackingFileOutputStream extends OutputStream {
+
+  private final ShuffleWriteMetrics writeMetrics;
+  private final FileOutputStream outputStream;
+
+  public TimeTrackingFileOutputStream(
+      ShuffleWriteMetrics writeMetrics,
+      FileOutputStream outputStream) {
+    this.writeMetrics = writeMetrics;
+    this.outputStream = outputStream;
+  }
+
+  /** Adds the nanoseconds elapsed since {@code begin} to the shuffle write time. */
+  private void chargeElapsedSince(long begin) {
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);
+  }
+
+  @Override
+  public void write(int b) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b);
+    chargeElapsedSince(begin);
+  }
+
+  @Override
+  public void write(byte[] b) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b);
+    chargeElapsedSince(begin);
+  }
+
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b, off, len);
+    chargeElapsedSince(begin);
+  }
+
+  @Override
+  public void flush() throws IOException {
+    // Forwarded without timing: only write calls count towards shuffle write time.
+    outputStream.flush();
+  }
+
+  @Override
+  public void close() throws IOException {
+    // Forwarded without timing, same as flush().
+    outputStream.close();
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1d1382c104fea..c4d26288de33d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -70,6 +70,7 @@ final class UnsafeShuffleExternalSorter {
private final BlockManager blockManager;
private final TaskContext taskContext;
private final boolean spillingEnabled;
+ private final ShuffleWriteMetrics writeMetrics;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSize;
@@ -97,7 +98,8 @@ public UnsafeShuffleExternalSorter(
TaskContext taskContext,
int initialSize,
int numPartitions,
- SparkConf conf) throws IOException {
+ SparkConf conf,
+ ShuffleWriteMetrics writeMetrics) throws IOException {
this.memoryManager = memoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
@@ -107,6 +109,7 @@ public UnsafeShuffleExternalSorter(
this.spillingEnabled = conf.getBoolean("spark.shuffle.spill", true);
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.writeMetrics = writeMetrics;
openSorter();
}
@@ -131,8 +134,24 @@ private void openSorter() throws IOException {
/**
* Sorts the in-memory records and writes the sorted records to a spill file.
* This method does not free the sort data structures.
+ *
+ * @param isSpill if true, this indicates that we're writing a spill and that bytes written should
+ * be counted towards shuffle spill metrics rather than shuffle write metrics.
*/
- private SpillInfo writeSpillFile() throws IOException {
+ private void writeSpillFile(boolean isSpill) throws IOException {
+
+ final ShuffleWriteMetrics writeMetricsToUse;
+
+ if (isSpill) {
+ // We're spilling, so bytes written should be counted towards spill rather than write.
+ // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+ // them towards shuffle bytes written.
+ writeMetricsToUse = new ShuffleWriteMetrics();
+ } else {
+ // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+ writeMetricsToUse = writeMetrics;
+ }
+
// This call performs the actual sort.
final UnsafeShuffleSorter.UnsafeShuffleSorterIterator sortedRecords =
sorter.getSortedIterator();
@@ -161,17 +180,8 @@ private SpillInfo writeSpillFile() throws IOException {
// OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
// around this, we pass a dummy no-op serializer.
final SerializerInstance ser = DummySerializerInstance.INSTANCE;
- // TODO: audit the metrics-related code and ensure proper metrics integration:
- // It's not clear how we should handle shuffle write metrics for spill files; currently, Spark
- // doesn't report IO time spent writing spill files (see SPARK-7413). This method,
- // writeSpillFile(), is called both when writing spill files and when writing the single output
- // file in cases where we didn't spill. As a result, we don't necessarily know whether this
- // should be reported as bytes spilled or as shuffle bytes written. We could defer the updating
- // of these metrics until the end of the shuffle write, but that would mean that that users
- // wouldn't get useful metrics updates in the UI from long-running tasks. Given this complexity,
- // I'm deferring these decisions to a separate follow-up commit or patch.
- writer =
- blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics());
+
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse);
int currentPartition = -1;
while (sortedRecords.hasNext()) {
@@ -185,8 +195,7 @@ private SpillInfo writeSpillFile() throws IOException {
spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
}
currentPartition = partition;
- writer =
- blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, new ShuffleWriteMetrics());
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetricsToUse);
}
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
@@ -220,7 +229,14 @@ private SpillInfo writeSpillFile() throws IOException {
spills.add(spillInfo);
}
}
- return spillInfo;
+
+ if (isSpill) {
+ writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+ // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+ // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+ // writeMetrics.incShuffleWriteTime(writeMetricsToUse.shuffleWriteTime());
+ taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+ }
}
/**
@@ -233,13 +249,12 @@ void spill() throws IOException {
org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" +
(spills.size() + (spills.size() > 1 ? " times" : " time")) + " so far)");
- final SpillInfo spillInfo = writeSpillFile();
+ writeSpillFile(true);
final long sorterMemoryUsage = sorter.getMemoryUsage();
sorter = null;
shuffleMemoryManager.release(sorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
- taskContext.taskMetrics().incDiskBytesSpilled(spillInfo.file.length());
openSorter();
}
@@ -389,7 +404,8 @@ public void insertRecord(
public SpillInfo[] closeAndGetSpills() throws IOException {
try {
if (sorter != null) {
- writeSpillFile();
+ // Do not count the final file towards the spill count.
+ writeSpillFile(false);
freeMemory();
}
return spills.toArray(new SpillInfo[spills.size()]);
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index 02bf7e321df12..7544ebbfeaad5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -157,7 +157,8 @@ private void open() throws IOException {
taskContext,
INITIAL_SORT_BUFFER_SIZE,
partitioner.numPartitions(),
- sparkConf);
+ sparkConf,
+ writeMetrics);
serBuffer = new MyByteArrayOutputStream(1024 * 1024);
serOutputStream = serializer.serializeStream(serBuffer);
}
@@ -210,6 +211,12 @@ void forceSorterToSpill() throws IOException {
sorter.spill();
}
+ /**
+ * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+ * number of spills and the IO compression codec.
+ *
+ * @return the partition lengths in the merged file.
+ */
private long[] mergeSpills(SpillInfo[] spills) throws IOException {
final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
@@ -223,30 +230,42 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
new FileOutputStream(outputFile).close(); // Create an empty file
return new long[partitioner.numPartitions()];
} else if (spills.length == 1) {
- // Note: we'll have to watch out for corner-cases in this code path when working on shuffle
- // metrics integration, since any metrics updates that are performed during the merge will
- // also have to be done here. In this branch, the shuffle technically didn't need to spill
- // because we're only trying to merge one file, so we may need to ensure that metrics that
- // would otherwise be counted as spill metrics are actually counted as regular write
- // metrics.
+ // Here, we don't need to perform any metrics updates because the bytes written to this
+ // output file would have already been counted as shuffle bytes written.
Files.move(spills[0].file, outputFile);
return spills[0].partitionLengths;
} else {
+ final long[] partitionLengths;
+ // There are multiple spills to merge, so none of these spill files' lengths were counted
+ // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+ // then the final output file's size won't necessarily be equal to the sum of the spill
+ // files' sizes. To guard against this case, we look at the output file's actual size when
+ // computing shuffle bytes written.
+ //
+ // We allow the individual merge methods to report their own IO times since different merge
+ // strategies use different IO techniques. We count IO during merge towards the shuffle
+ // write time, which appears to be consistent with the "not bypassing merge-sort"
+ // branch in ExternalSorter.
if (fastMergeEnabled && fastMergeIsSupported) {
// Compression is disabled or we are using an IO compression codec that supports
// decompression of concatenated compressed streams, so we can perform a fast spill merge
// that doesn't need to interpret the spilled bytes.
if (transferToEnabled) {
logger.debug("Using transferTo-based fast merge");
- return mergeSpillsWithTransferTo(spills, outputFile);
+ partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
} else {
logger.debug("Using fileStream-based fast merge");
- return mergeSpillsWithFileStream(spills, outputFile, null);
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
}
} else {
logger.debug("Using slow merge");
- return mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+ partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
}
+ // The final shuffle spill's write would have directly updated shuffleBytesWritten, so
+ // we need to decrement to avoid double-counting this write.
+ writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+ writeMetrics.incShuffleBytesWritten(outputFile.length());
+ return partitionLengths;
}
} catch (IOException e) {
if (outputFile.exists() && !outputFile.delete()) {
@@ -271,7 +290,8 @@ private long[] mergeSpillsWithFileStream(
}
for (int partition = 0; partition < numPartitions; partition++) {
final long initialFileLength = outputFile.length();
- mergedFileOutputStream = new FileOutputStream(outputFile, true);
+ mergedFileOutputStream =
+ new TimeTrackingFileOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
if (compressionCodec != null) {
mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
}
@@ -321,6 +341,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
long bytesToTransfer = partitionLengthInSpill;
final FileChannel spillInputChannel = spillInputChannels[i];
+ final long writeStartTime = System.nanoTime();
while (bytesToTransfer > 0) {
final long actualBytesTransferred = spillInputChannel.transferTo(
spillInputChannelPositions[i],
@@ -329,6 +350,7 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
spillInputChannelPositions[i] += actualBytesTransferred;
bytesToTransfer -= actualBytesTransferred;
}
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
bytesWrittenToMergedFile += partitionLengthInSpill;
partitionLengths[partition] += partitionLengthInSpill;
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index c53e0fcf44880..8451e8d9a9785 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -29,13 +29,16 @@
import com.google.common.collect.HashMultiset;
import com.google.common.io.ByteStreams;
import org.junit.After;
-import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+import static org.junit.Assert.*;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
@@ -70,6 +73,7 @@ public class UnsafeShuffleWriterSuite {
final LinkedList spillFilesCreated = new LinkedList();
SparkConf conf;
final Serializer serializer = new KryoSerializer(new SparkConf());
+ TaskMetrics taskMetrics;
@Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -94,7 +98,7 @@ public void tearDown() {
Utils.deleteRecursively(tempDir);
final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
if (leakedMemory != 0) {
- Assert.fail("Test leaked " + leakedMemory + " bytes of managed memory");
+ fail("Test leaked " + leakedMemory + " bytes of managed memory");
}
}
@@ -107,6 +111,7 @@ public void setUp() throws IOException {
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
conf = new SparkConf();
+ taskMetrics = new TaskMetrics();
when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
@@ -183,7 +188,7 @@ public Tuple2 answer(
}
});
- when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics);
when(shuffleDep.serializer()).thenReturn(Option.apply(serializer));
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -205,7 +210,7 @@ private UnsafeShuffleWriter
+
+ org.hamcrest
+ hamcrest-core
+ 1.3
+ test
+
+
+ org.hamcrest
+ hamcrest-library
+ 1.3
+ test
+ com.novocodejunit-interface