Merge pull request #416 from stephenh/morefinally
Call executeOnCompleteCallbacks in more finally blocks.
mateiz committed Jan 25, 2013
2 parents 04bfee2 + 8efbda0 commit 2435b7b
Showing 2 changed files with 30 additions and 29 deletions.
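
Both files make the same change: the call to taskContext.executeOnCompleteCallbacks() moves out of the straight-line success path and into a finally block, so a task's completion callbacks still run when the task body throws. Below is a minimal, self-contained sketch of that pattern; the DemoTaskContext class, the addOnCompleteCallback method, and the runTaskFragile/runTaskRobust helpers are illustrative stand-ins, not Spark's actual API.

import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for Spark's TaskContext, just enough to show the pattern.
class DemoTaskContext {
  private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

  // Register a callback to run when the task finishes (e.g. closing streams, freeing buffers).
  def addOnCompleteCallback(f: () => Unit): Unit = onCompleteCallbacks += f

  // Run every registered callback, in registration order.
  def executeOnCompleteCallbacks(): Unit = onCompleteCallbacks.foreach(f => f())
}

object CallbackPatternDemo {
  // Before the change: callbacks only run if body succeeds; an exception skips cleanup.
  def runTaskFragile[T](ctx: DemoTaskContext)(body: => T): T = {
    val result = body
    ctx.executeOnCompleteCallbacks()
    result
  }

  // After the change: callbacks run in a finally block, on success and failure alike.
  def runTaskRobust[T](ctx: DemoTaskContext)(body: => T): T = {
    try {
      body
    } finally {
      ctx.executeOnCompleteCallbacks()
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new DemoTaskContext
    ctx.addOnCompleteCallback(() => println("cleanup ran"))
    try {
      runTaskRobust(ctx) { sys.error("task failed") }
    } catch {
      // "cleanup ran" is printed before the exception reaches this handler.
      case e: RuntimeException => println("caught: " + e.getMessage)
    }
  }
}

The two diffs below apply exactly this restructuring to the local job-execution path in DAGScheduler and to ShuffleMapTask.run.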
core/src/main/scala/spark/scheduler/DAGScheduler.scala (13 changes: 7 additions & 6 deletions)
@@ -40,7 +40,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     eventQueue.put(HostLost(host))
   }
 
-  // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures.
+  // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
   override def taskSetFailed(taskSet: TaskSet, reason: String) {
     eventQueue.put(TaskSetFailed(taskSet, reason))
   }
@@ -54,8 +54,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   // resubmit failed stages
   val POLL_TIMEOUT = 10L
 
-  private val lock = new Object  // Used for access to the entire DAGScheduler
-
   private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
 
   val nextRunId = new AtomicInteger(0)
@@ -337,9 +335,12 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
       val rdd = job.finalStage.rdd
       val split = rdd.splits(job.partitions(0))
       val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
-      val result = job.func(taskContext, rdd.iterator(split, taskContext))
-      taskContext.executeOnCompleteCallbacks()
-      job.listener.taskSucceeded(0, result)
+      try {
+        val result = job.func(taskContext, rdd.iterator(split, taskContext))
+        job.listener.taskSucceeded(0, result)
+      } finally {
+        taskContext.executeOnCompleteCallbacks()
+      }
     } catch {
       case e: Exception =>
         job.listener.jobFailed(e)
core/src/main/scala/spark/scheduler/ShuffleMapTask.scala (46 changes: 23 additions & 23 deletions)
@@ -81,7 +81,7 @@ private[spark] class ShuffleMapTask(
   with Externalizable
   with Logging {
 
-  def this() = this(0, null, null, 0, null)
+  protected def this() = this(0, null, null, 0, null)
 
   var split = if (rdd == null) {
     null
@@ -117,34 +117,34 @@ private[spark] class ShuffleMapTask(
 
   override def run(attemptId: Long): MapStatus = {
     val numOutputSplits = dep.partitioner.numPartitions
-    val partitioner = dep.partitioner
 
     val taskContext = new TaskContext(stageId, partition, attemptId)
+    try {
+      // Partition the map output.
+      val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
+      for (elem <- rdd.iterator(split, taskContext)) {
+        val pair = elem.asInstanceOf[(Any, Any)]
+        val bucketId = dep.partitioner.getPartition(pair._1)
+        buckets(bucketId) += pair
+      }
+      val bucketIterators = buckets.map(_.iterator)
 
-    // Partition the map output.
-    val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
-    for (elem <- rdd.iterator(split, taskContext)) {
-      val pair = elem.asInstanceOf[(Any, Any)]
-      val bucketId = partitioner.getPartition(pair._1)
-      buckets(bucketId) += pair
-    }
-    val bucketIterators = buckets.map(_.iterator)
+      val compressedSizes = new Array[Byte](numOutputSplits)
 
-    val compressedSizes = new Array[Byte](numOutputSplits)
+      val blockManager = SparkEnv.get.blockManager
+      for (i <- 0 until numOutputSplits) {
+        val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
+        // Get a Scala iterator from Java map
+        val iter: Iterator[(Any, Any)] = bucketIterators(i)
+        val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+        compressedSizes(i) = MapOutputTracker.compressSize(size)
+      }
 
-    val blockManager = SparkEnv.get.blockManager
-    for (i <- 0 until numOutputSplits) {
-      val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
-      // Get a Scala iterator from Java map
-      val iter: Iterator[(Any, Any)] = bucketIterators(i)
-      val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
-      compressedSizes(i) = MapOutputTracker.compressSize(size)
+      return new MapStatus(blockManager.blockManagerId, compressedSizes)
+    } finally {
+      // Execute the callbacks on task completion.
+      taskContext.executeOnCompleteCallbacks()
     }
-
-    // Execute the callbacks on task completion.
-    taskContext.executeOnCompleteCallbacks()
-
-    return new MapStatus(blockManager.blockManagerId, compressedSizes)
   }
 
   override def preferredLocations: Seq[String] = locs
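One subtlety in the ShuffleMapTask hunk above: the return statement now sits inside the try block, and Scala, like Java, still executes the finally clause before the method actually returns, so the completion callbacks are not skipped by the early return. A tiny, self-contained sketch of that behavior (the FinallyWithReturn object and its cleanupRan flag are hypothetical, purely for illustration):

object FinallyWithReturn {
  var cleanupRan = false  // Records whether the finally block executed.

  def compute(): Int = {
    try {
      return 42           // Returning from inside try...
    } finally {
      cleanupRan = true   // ...still runs the finally block first.
    }
  }

  def main(args: Array[String]): Unit = {
    val result = compute()
    println("result=" + result + " cleanupRan=" + cleanupRan)  // result=42 cleanupRan=true
  }
}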

0 comments on commit 2435b7b
