Commit 55be1f3

merge master

mengxr committed Apr 9, 2015
2 parents: d63b5cc + 9418280
Showing 67 changed files with 471 additions and 412 deletions.
29 changes: 0 additions & 29 deletions core/src/main/scala/org/apache/spark/TaskContextHelper.scala

This file was deleted.
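
The deleted TaskContextHelper was a thin shim around thread-local task-context state; the call-site changes below switch to methods of the same name on the TaskContext companion object. Below is a minimal, self-contained sketch of that thread-local plumbing, assuming a ThreadLocal-backed implementation; the object name TaskContextSketch and the stub trait are illustrative, not Spark's actual source.

// Hedged sketch only; the real methods live on org.apache.spark.TaskContext and
// carry Spark-internal visibility modifiers.
object TaskContextSketch {
  trait TaskContext // stand-in for org.apache.spark.TaskContext

  // One running task per executor thread, so a plain ThreadLocal suffices.
  private val current = new ThreadLocal[TaskContext]

  /** The context of the task running on this thread, or null if none is installed. */
  def get(): TaskContext = current.get()

  /** Installed by the runner before user code executes (see Task.run below). */
  def setTaskContext(tc: TaskContext): Unit = current.set(tc)

  /** Called in a finally block so the worker thread does not leak a stale context. */
  def unset(): Unit = current.remove()
}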

67 changes: 36 additions & 31 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -50,6 +50,10 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
* not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* Here's a checklist to use when making or reviewing changes to this class:
*
* - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to
* include the new structure. This will help to catch memory leaks.
*/
private[spark]
class DAGScheduler(
@@ -111,6 +115,8 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]

private [scheduler] val outputCommitCoordinator = env.outputCommitCoordinator

// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
@@ -128,8 +134,6 @@
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)

private val outputCommitCoordinator = env.outputCommitCoordinator

// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -641,13 +645,13 @@
val split = rdd.partitions(job.partitions(0))
val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
TaskContextHelper.unset()
TaskContext.unset()
}
} catch {
case e: Exception =>
@@ -710,9 +714,10 @@
// cancelling the stages because if the DAG scheduler is stopped, the entire application
// is in the process of getting stopped.
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
runningStages.foreach { stage =>
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
// The `toArray` here is necessary so that we don't iterate over `runningStages` while
// mutating it.
runningStages.toArray.foreach { stage =>
markStageAsFinished(stage, Some(stageFailedMessage))
}
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
@@ -887,10 +892,9 @@
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
outputCommitCoordinator.stageEnd(stage.id)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
// Because we posted SparkListenerStageSubmitted earlier, we should mark
// the stage as completed here in case there are no tasks to run
markStageAsFinished(stage, None)

val debugString = stage match {
case stage: ShuffleMapStage =>
@@ -902,7 +906,6 @@
s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
}
logDebug(debugString)
runningStages -= stage
}
}

@@ -968,22 +971,6 @@
}

val stage = stageIdToStage(task.stageId)

def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
val serviceTime = stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
case _ => "Unknown"
}
if (errorMessage.isEmpty) {
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.latestInfo.completionTime = Some(clock.getTimeMillis())
} else {
stage.latestInfo.stageFailed(errorMessage.get)
logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
}
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
runningStages -= stage
}
event.reason match {
case Success =>
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
@@ -1099,7 +1086,6 @@
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}

if (disallowStageRetryForTest) {
@@ -1215,6 +1201,26 @@
submitWaitingStages()
}

/**
* Marks a stage as finished and removes it from the list of running stages.
*/
private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
val serviceTime = stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
case _ => "Unknown"
}
if (errorMessage.isEmpty) {
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.latestInfo.completionTime = Some(clock.getTimeMillis())
} else {
stage.latestInfo.stageFailed(errorMessage.get)
logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
}
outputCommitCoordinator.stageEnd(stage.id)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
runningStages -= stage
}

/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -1264,8 +1270,7 @@
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
stage.latestInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
markStageAsFinished(stage, Some(failureReason))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
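
The shutdown path above now routes stage completion through the single markStageAsFinished helper, which removes each stage from runningStages; that is why the loop takes a toArray snapshot before iterating. A small, self-contained illustration of why the snapshot matters (illustrative code, not Spark source):

import scala.collection.mutable

// Removing elements from a mutable set while foreach-ing over it is undefined
// behaviour in Scala collections: entries may be skipped or visited oddly.
val runningStages = mutable.HashSet("stage 0", "stage 1", "stage 2")

// Unsafe pattern (left commented out):
// runningStages.foreach(stage => runningStages -= stage)

// Safe pattern used in the commit: iterate over an immutable snapshot.
runningStages.toArray.foreach { stage =>
  runningStages -= stage // stands in for markStageAsFinished(stage, ...)
}
assert(runningStages.isEmpty)
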
core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -59,6 +59,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]

/**
* Returns whether the OutputCommitCoordinator's internal data structures are all empty.
*/
def isEmpty: Boolean = {
authorizedCommittersByStage.isEmpty
}

/**
* Called by tasks to ask whether they can commit their output to HDFS.
*
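
The new isEmpty accessor exposes the coordinator's only piece of per-stage state so the scheduler test suite (see DAGSchedulerSuite below) can verify that nothing leaks once markStageAsFinished calls stageEnd for every stage. A hedged sketch of the bookkeeping being inspected; commit arbitration is omitted and the method signatures are assumptions, not the real class:

import scala.collection.mutable

// Illustrative bookkeeping only: stageId -> (partitionId -> authorized task attempt).
class CommitBookkeepingSketch {
  private val authorizedCommittersByStage =
    mutable.Map.empty[Int, mutable.Map[Int, Long]]

  def stageStart(stageId: Int): Unit =
    authorizedCommittersByStage(stageId) = mutable.Map.empty

  // markStageAsFinished now calls stageEnd for every stage, so entries are always dropped.
  def stageEnd(stageId: Int): Unit = {
    authorizedCommittersByStage.remove(stageId)
  }

  // New in this commit: lets the test suite assert that no per-stage state leaked.
  def isEmpty: Boolean = authorizedCommittersByStage.isEmpty
}
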
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer

import scala.collection.mutable.HashMap

import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -54,7 +54,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
if (_killed) {
@@ -64,7 +64,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
TaskContextHelper.unset()
TaskContext.unset()
}
}

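With Task.run installing and clearing the context through the companion object, user code inside a task can still look it up via the public TaskContext.get accessor. An illustrative end-to-end usage, not part of this commit, assuming a local SparkContext:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextGetExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ctx-example").setMaster("local[2]"))
    // Inside the task, TaskContext.get() returns the context that Task.run installed.
    val tagged = sc.parallelize(1 to 8, numSlices = 2)
      .map(x => (TaskContext.get().partitionId(), x))
      .collect()
    println(tagged.mkString(", ")) // e.g. (0,1), (0,2), ..., (1,8)
    sc.stop()
  }
}
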
4 changes: 2 additions & 2 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -242,14 +242,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
shuffleSpillCompress <- Set(true, false);
shuffleCompress <- Set(true, false)
) {
val conf = new SparkConf()
val myConf = conf.clone()
.setAppName("test")
.setMaster("local")
.set("spark.shuffle.spill.compress", shuffleSpillCompress.toString)
.set("spark.shuffle.compress", shuffleCompress.toString)
.set("spark.shuffle.memoryFraction", "0.001")
resetSparkContext()
sc = new SparkContext(conf)
sc = new SparkContext(myConf)
try {
sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
} catch {
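The test above now clones the suite's shared conf instead of building a fresh SparkConf, presumably so that settings applied by subclasses of ShuffleSuite (for example a different shuffle manager) survive into the per-iteration configuration. A small sketch of the difference; the "spark.shuffle.manager" value below is an assumed example, not taken from this commit:

import org.apache.spark.SparkConf

val base = new SparkConf().set("spark.shuffle.manager", "sort") // setting assumed to come from a subclass suite
val myConf = base.clone()
  .setAppName("test")
  .setMaster("local")
  .set("spark.shuffle.compress", "false")

// The cloned conf keeps the base setting; new SparkConf() would have dropped it.
assert(myConf.get("spark.shuffle.manager") == "sort")
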
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -783,6 +783,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
assert(scheduler.runningStages.isEmpty)
assert(scheduler.shuffleToMapStage.isEmpty)
assert(scheduler.waitingStages.isEmpty)
assert(scheduler.outputCommitCoordinator.isEmpty)
}

// Nothing in this test should break if the task info's fields are null, but
2 changes: 1 addition & 1 deletion docs/ml-guide.md
@@ -493,7 +493,7 @@ from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.sql import Row, SQLContext

sc = SparkContext(appName="SimpleTextClassificationPipeline")
sqlCtx = SQLContext(sc)
sqlContext = SQLContext(sc)

# Prepare training documents, which are labeled.
LabeledDocument = Row("id", "text", "label")
4 changes: 2 additions & 2 deletions docs/sql-programming-guide.md
@@ -1642,15 +1642,15 @@ moved into the udf object in `SQLContext`.
<div data-lang="scala" markdown="1">
{% highlight java %}

sqlCtx.udf.register("strLen", (s: String) => s.length())
sqlContext.udf.register("strLen", (s: String) => s.length())

{% endhighlight %}
</div>

<div data-lang="java" markdown="1">
{% highlight java %}

sqlCtx.udf().register("strLen", (String s) -> { s.length(); });
sqlContext.udf().register("strLen", (String s) -> { s.length(); });

{% endhighlight %}
</div>
(The remaining changed files in this commit are not shown here.)
