apache#13 fix block miss when df is reused
hn5092 committed Jan 21, 2019
1 parent 985213b commit 5112cff
Showing 4 changed files with 26 additions and 18 deletions.

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

@@ -35,6 +35,8 @@ private[spark] class BroadcastManager(
     securityManager: SecurityManager)
   extends Logging {
 
+  val cleanQueryBroadcast = conf.getBoolean("spark.broadcast.autoClean.enabled", false)
+
   private var initialized = false
   private var broadcastFactory: BroadcastFactory = null
   var cachedBroadcast = new ConcurrentHashMap[String, ListBuffer[Long]]()
@@ -65,15 +67,16 @@ private[spark] class BroadcastManager(
   }
 
   def cleanBroadCast(executionId: String): Unit = {
-    if (cachedBroadcast.containsKey(executionId)) {
-      cachedBroadcast.get(executionId).foreach(broadcastId => unbroadcast(broadcastId, true, false))
-      cachedBroadcast.remove(executionId)
-    }
+    if (cachedBroadcast.containsKey(executionId)) {
+      cachedBroadcast.get(executionId)
+        .foreach(broadcastId => unbroadcast(broadcastId, true, false))
+      cachedBroadcast.remove(executionId)
+    }
   }
 
   def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, executionId: String): Broadcast[T] = {
     val broadcastId = nextBroadcastId.getAndIncrement()
-    if (executionId != null) {
+    if (executionId != null && cleanQueryBroadcast) {
       if (cachedBroadcast.containsKey(executionId)) {
         cachedBroadcast.get(executionId) += broadcastId
       } else {
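
For context, a minimal sketch (not part of this commit) of how the new flag would be turned on; it assumes nothing beyond the configuration key and default introduced in the hunk above, and the session setup itself is illustrative:

    // Hypothetical driver-side setup: opt in to per-query broadcast cleanup.
    // "spark.broadcast.autoClean.enabled" defaults to false, so existing jobs keep the old behaviour.
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession

    val conf = new SparkConf().set("spark.broadcast.autoClean.enabled", "true")
    val spark = SparkSession.builder()
      .config(conf)
      .master("local[2]")
      .appName("broadcast-auto-clean-demo")
      .getOrCreate()

With the flag enabled, newBroadcast(value_, isLocal, executionId) records each new broadcast id under its SQL execution id, and cleanBroadCast(executionId) later unbroadcasts that whole group in one call.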

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

@@ -692,7 +692,8 @@ private[spark] class DAGScheduler(
       // Return immediately if the job is running 0 tasks
       return new JobWaiter[U](this, jobId, 0, resultHandler)
     }
-
+    val executionId = sc.getLocalProperty("spark.sql.execution.id")
+    logInfo(s"submit job : $jobId, executionId is $executionId")
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@@ -1082,6 +1083,7 @@ private[spark] class DAGScheduler(
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
+    logInfo(s"submit stage ${stage.id} with jobId: $jobId")
 
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
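
The logging added above reads the SQL execution id that Spark attaches to jobs as a local property. A hedged sketch (not part of the commit) of how the same id can be observed from user code; the listener class and names are illustrative, while SparkListenerJobStart and its properties field are standard Spark APIs:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
    import org.apache.spark.sql.SparkSession

    // Hypothetical listener performing the same lookup as the added logInfo in submitJob:
    // jobs submitted while a SQL query runs carry "spark.sql.execution.id" in their properties.
    class ExecutionIdListener extends SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
        val executionId = Option(jobStart.properties)
          .map(_.getProperty("spark.sql.execution.id"))
          .orNull
        println(s"job ${jobStart.jobId} belongs to SQL execution $executionId")
      }
    }

    val spark = SparkSession.builder().master("local[2]").appName("execution-id-demo").getOrCreate()
    spark.sparkContext.addSparkListener(new ExecutionIdListener)
    spark.range(0, 100).count()   // the listener prints the execution id this query ran under
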
22 changes: 11 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -530,18 +530,18 @@ private[spark] class TaskSetManager(
   private def maybeFinishTaskSet() {
     if (isZombie && runningTasks == 0) {
       sched.taskSetFinished(this)
-      val broadcastId = taskSet.tasks.head match {
-        case resultTask: ResultTask[Any, Any] =>
-          resultTask.taskBinary.id
-        case shuffleMapTask: ShuffleMapTask =>
-          shuffleMapTask.taskBinary.id
-      }
-      SparkEnv.get.broadcastManager.unbroadcast(broadcastId, true, false)
       if (tasksSuccessful == numTasks) {
-        blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
-          taskSet.stageId,
-          taskSet.stageAttemptId,
-          taskSetBlacklistHelperOpt.get.execToFailures))
+        val broadcastId = taskSet.tasks.head match {
+          case resultTask: ResultTask[Any, Any] =>
+            resultTask.taskBinary.id
+          case shuffleMapTask: ShuffleMapTask =>
+            shuffleMapTask.taskBinary.id
+        }
+        SparkEnv.get.broadcastManager.unbroadcast(broadcastId, true, false)
+        blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet(
+          taskSet.stageId,
+          taskSet.stageAttemptId,
+          taskSetBlacklistHelperOpt.get.execToFailures))
       }
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala

@@ -21,10 +21,12 @@ import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicLong
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 
-object SQLExecution {
+
+object SQLExecution extends Logging{
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
@@ -62,6 +64,7 @@ object SQLExecution {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
+    logInfo(s"Execution Id is $executionId ")
     sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
     executionIdToQueryExecution.put(executionId, queryExecution)
     try {
