[SPARK-24552][CORE][SQL] Use task ID instead of attempt number for writes.

This passes the unique task attempt ID, rather than the attempt number, to v2 data sources, because attempt numbers are reused when stages are retried. When attempt numbers are reused, sources that track data by partition ID and attempt number may incorrectly clean up data, because the same attempt number can be both committed and aborted.
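
For context (illustration only, not part of the original commit message), a minimal sketch of the two identifiers as exposed by TaskContext; the values in the comments are purely illustrative:

import org.apache.spark.TaskContext

// Inside a running task: attemptNumber() starts from 0 for each task and is
// reused when the whole stage is retried, so (partitionId, attemptNumber) is
// not unique across stage attempts. taskAttemptId() is unique across the whole
// application, which is why v2 writers now receive it instead.
val ctx: TaskContext = TaskContext.get()
val partitionId: Int = ctx.partitionId()
val attemptNumber: Int = ctx.attemptNumber()   // e.g. 0 in stage attempt 0, and 0 again after a stage retry
val taskAttemptId: Long = ctx.taskAttemptId()  // e.g. 7 at first, then 23 for the retried task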

For v1 / Hadoop writes, generate a unique ID by combining the stage attempt number and the task attempt number, to avoid a similar problem.
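
A rough sketch of that v1 workaround, mirroring the SparkHadoopWriter hunk below (the helper name here is invented):

import org.apache.spark.TaskContext

// Pack the stage attempt number into the upper bits and the task attempt number
// into the lower 16 bits. The result stays unique across stage retries as long
// as neither counter exceeds the "no more than Short.MaxValue attempts"
// assumption noted in the diff.
def uniqueV1AttemptId(context: TaskContext): Int =
  (context.stageAttemptNumber << 16) | context.attemptNumber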

Closes apache#21558

Author: Marcelo Vanzin <vanzin@cloudera.com>
Author: Ryan Blue <blue@apache.org>

Closes apache#21606 from vanzin/SPARK-24552.2.

Ref: LIHADOOP-48531
Marcelo Vanzin authored and otterc committed Oct 25, 2019
1 parent 1f3e9fe commit ce746ec
Showing 11 changed files with 45 additions and 37 deletions.
@@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging {
// Try to write all RDD partitions as a Hadoop OutputFormat.
try {
val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
+ // SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers.
+ // Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently.
+ val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber
+
executeTask(
context = context,
config = config,
jobTrackerId = jobTrackerId,
commitJobId = commitJobId,
sparkPartitionId = context.partitionId,
- sparkAttemptNumber = context.attemptNumber,
+ sparkAttemptNumber = attemptId,
committer = committer,
iterator = iter)
})
@@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory(

override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes)
}
@@ -64,8 +64,8 @@ public interface DataSourceWriter {
DataWriterFactory<Row> createWriterFactory();

/**
- * Returns whether Spark should use the commit coordinator to ensure that at most one attempt for
- * each task commits.
+ * Returns whether Spark should use the commit coordinator to ensure that at most one task for
+ * each partition commits.
*
* @return true if commit coordinator should be used, false otherwise.
*/
@@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
* is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it.
*
* Note that speculative execution may cause multiple tasks to run for a partition. By default,
- * Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can
+ * Spark uses the commit coordinator to allow at most one task to commit. Implementations can
* disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple
- * attempts may have committed successfully and one successful commit message per task will be
+ * tasks may have committed successfully and one successful commit message per task will be
* passed to this commit method. The remaining commit messages are ignored by Spark.
*/
void commit(WriterCommitMessage[] messages);
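
For illustration only (hypothetical code, not from this commit): a stripped-down source that opts out of the coordinator under the API documented above. All class names are invented, and commit() simply discards its messages, so such a source must tolerate more than one successful task per partition.

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}

// With useCommitCoordinator = false, speculative or retried tasks for the same
// partition may all commit, and commit() receives one message per successful task.
class NoCoordinatorWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[Row] = NoOpWriterFactory
  override def useCommitCoordinator(): Boolean = false
  override def commit(messages: Array[WriterCommitMessage]): Unit = ()
  override def abort(messages: Array[WriterCommitMessage]): Unit = ()
}

// Note the post-change factory signature: taskId is a Long task attempt ID, not an attempt number.
case object NoOpWriterFactory extends DataWriterFactory[Row] {
  override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[Row] =
    new NoOpDataWriter
}

class NoOpDataWriter extends DataWriter[Row] {
  override def write(record: Row): Unit = ()
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = ()
}
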
@@ -22,7 +22,7 @@
import org.apache.spark.annotation.InterfaceStability;

/**
- * A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is
+ * A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is
* responsible for writing data for an input RDD partition.
*
* One Spark task has one exclusive data writer, so there is no thread-safe concern.
@@ -39,14 +39,14 @@
* {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data
* writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an
* exception will be sent to the driver side, and Spark may retry this writing task a few times.
- * In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a
- * different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
+ * In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a
+ * different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
* when the configured number of retries is exhausted.
*
* Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task
* takes too long to finish. Different from retried tasks, which are launched one by one after the
* previous one fails, speculative tasks are running simultaneously. It's possible that one input
- * RDD partition has multiple data writers with different `attemptNumber` running at the same time,
+ * RDD partition has multiple data writers with different `taskId` running at the same time,
* and data sources should guarantee that these data writers don't conflict and can work together.
* Implementations can coordinate with driver during {@link #commit()} to make sure only one of
* these data writers can commit successfully. Or implementations can allow all of them to commit
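
Also for illustration only (hypothetical, not from this commit): a local-filesystem DataWriter sketching the pattern described above. Each attempt stages its output under its unique taskId and only promotes the file on commit(), so concurrent speculative writers for the same partition never collide. Paths and names are made up.

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths, StandardCopyOption}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

class LocalStagingDataWriter(dir: String, partitionId: Int, taskId: Long)
    extends DataWriter[Row] {

  // The staging file name embeds the unique taskId, so two attempts never share a path.
  private val staging = Paths.get(dir, s"part-$partitionId-attempt-$taskId.staging")
  private val out = Files.newBufferedWriter(staging, StandardCharsets.UTF_8)

  override def write(record: Row): Unit = {
    out.write(record.mkString(","))
    out.newLine()
  }

  override def commit(): WriterCommitMessage = {
    out.close()
    // With the commit coordinator enabled, only one authorized attempt reaches this
    // point, so promoting the staging file to the final per-partition name is safe.
    Files.move(staging, Paths.get(dir, s"part-$partitionId"), StandardCopyOption.ATOMIC_MOVE)
    new WriterCommitMessage {}
  }

  override def abort(): Unit = {
    out.close()
    Files.deleteIfExists(staging)
  }
}
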
@@ -42,15 +42,12 @@ public interface DataWriterFactory<T> extends Serializable {
* Usually Spark processes many RDD partitions at the same time,
* implementations should use the partition id to distinguish writers for
* different partitions.
- * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task
- failed, Spark launches a new task wth the same task id but different
- attempt number. Or a task is too slow, Spark launches new tasks wth the
- same task id but different attempt number, which means there are multiple
- tasks with the same task id running at the same time. Implementations can
- use this attempt number to distinguish writers of different task attempts.
+ * @param taskId A unique identifier for a task that is performing the write of the partition
+ data. Spark may run multiple tasks for the same partition (due to speculation
+ or task failures, for example).
* @param epochId A monotonically increasing id for streaming queries that are split in to
* discrete periods of execution. For non-streaming queries,
* this ID will always be 0.
*/
- DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+ DataWriter<T> createDataWriter(int partitionId, long taskId, long epochId);
}
@@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
- import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
- import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
+ import org.apache.spark.sql.execution.streaming.MicroBatchExecution
import org.apache.spark.sql.sources.v2.writer._
- import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -109,23 +107,29 @@
iter: Iterator[InternalRow],
useCommitCoordinator: Boolean): WriterCommitMessage = {
val stageId = context.stageId()
+ val stageAttempt = context.stageAttemptNumber()
val partId = context.partitionId()
+ val taskId = context.taskAttemptId()
val attemptId = context.attemptNumber()
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
- val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
+ val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong)

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
- iter.foreach(dataWriter.write)
+ while (iter.hasNext) {
+ dataWriter.write(iter.next())
+ }

val msg = if (useCommitCoordinator) {
val coordinator = SparkEnv.get.outputCommitCoordinator
val commitAuthorized = coordinator.canCommit(context.stageId(), partId, attemptId)
if (commitAuthorized) {
- logInfo(s"Writer for stage $stageId, task $partId.$attemptId is authorized to commit.")
+ logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId" +
+ s"stage $stageId.$stageAttempt)")
dataWriter.commit()
} else {
- val message = s"Stage $stageId, task $partId.$attemptId: driver did not authorize commit"
+ val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId" +
+ s"stage $stageId.$stageAttempt)"
logInfo(message)
// throwing CommitDeniedException will trigger the catch block for abort
throw new CommitDeniedException(message, stageId, partId, attemptId)
@@ -136,15 +140,18 @@
dataWriter.commit()
}

- logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.")
+ logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" +
+ s"stage $stageId.$stageAttempt)")

msg

})(catchBlock = {
// If there is an error, abort this writer
- logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.")
+ logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId" +
+ s"stage $stageId.$stageAttempt)")
dataWriter.abort()
- logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.")
+ logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId" +
+ s"stage $stageId.$stageAttempt)")
})
}
}
@@ -155,10 +162,10 @@ class InternalRowDataWriterFactory(

override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new InternalRowDataWriter(
- rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId),
+ rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
RowEncoder.apply(schema).resolveAndBind())
}
}
@@ -53,7 +53,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
val dataIterator = prev.compute(split, context)
dataWriter = writeTask.createDataWriter(
context.partitionId(),
- context.attemptNumber(),
+ context.taskAttemptId(),
EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(dataIterator.next())
@@ -88,7 +88,7 @@ case class ForeachWriterFactory[T](
extends DataWriterFactory[InternalRow] {
override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): ForeachDataWriter[T] = {
new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[Row] = {
new PackedRowDataWriter()
}
@@ -179,7 +179,7 @@ class MemoryStreamWriter(
case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[Row] = {
new MemoryDataWriter(partitionId, outputMode)
}
@@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable

override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[Row] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
- val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
+ val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
new SimpleCSVDataWriter(fs, filePath)
}
@@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali

override def createDataWriter(
partitionId: Int,
- attemptNumber: Int,
+ taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
- val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
+ val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
new InternalRowCSVDataWriter(fs, filePath)
}
