Merge pull request alteryx#61 from kayousterhout/daemon_thread
Unified daemon thread pools

As requested by @mateiz in an earlier pull request, this refactors the various daemon thread pools to use a shared set of factory methods in Utils.scala, and changes those thread-pool-creation methods to produce named thread pools for easier debugging.

(cherry picked from commit 983b83f)
Signed-off-by: Reynold Xin <rxin@apache.org>
mateiz authored and rxin committed Oct 18, 2013
1 parent 2760055 commit b6ce111
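
For context, a minimal sketch of the pattern this commit applies at every call site; the prefix is one of the real names from the diff below, but the snippet itself is illustrative rather than part of the change:

import org.apache.spark.util.Utils

// Before: pools were built with an anonymous daemon thread factory, so
// thread dumps showed generic names like "pool-7-thread-2".
//   val threadPool = Utils.newDaemonCachedThreadPool()

// After: every call site passes a descriptive prefix, and each thread is
// named "<prefix>-<id>", e.g. "Result resolver thread-0".
val threadPool = Utils.newDaemonCachedThreadPool("Result resolver thread")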
Showing 7 changed files with 29 additions and 38 deletions.
@@ -326,7 +326,8 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private var blocksInRequestBitVector = new BitSet(totalBlocks)

override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
+ var threadPool = Utils.newDaemonFixedThreadPool(
+   MultiTracker.MaxChatSlots, "Bit Torrent Chatter")

while (hasBlocks.get < totalBlocks) {
var numThreadsToCreate = 0
@@ -736,7 +737,7 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
private var setOfCompletedSources = Set[SourceInfo]()

override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
+ var threadPool = Utils.newDaemonCachedThreadPool("Bit torrent guide multiple requests")
var serverSocket: ServerSocket = null

serverSocket = new ServerSocket(0)
@@ -927,7 +928,8 @@ private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal:
class ServeMultipleRequests
extends Thread with Logging {
// Server at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
+ var threadPool = Utils.newDaemonFixedThreadPool(
+   MultiTracker.MaxChatSlots, "Bit torrent serve multiple requests")

override def run() {
var serverSocket = new ServerSocket(0)
@@ -137,7 +137,7 @@ extends Logging {
class TrackMultipleValues
extends Thread with Logging {
override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
+ var threadPool = Utils.newDaemonCachedThreadPool("Track multiple values")
var serverSocket: ServerSocket = null

serverSocket = new ServerSocket(DriverTrackerPort)
@@ -291,7 +291,7 @@ extends Broadcast[T](id) with Logging with Serializable {
private var setOfCompletedSources = Set[SourceInfo]()

override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
+ var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast guide multiple requests")
var serverSocket: ServerSocket = null

serverSocket = new ServerSocket(0)
@@ -493,7 +493,7 @@ extends Broadcast[T](id) with Logging with Serializable {
class ServeMultipleRequests
extends Thread with Logging {

- var threadPool = Utils.newDaemonCachedThreadPool()
+ var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast serve multiple requests")

override def run() {
var serverSocket = new ServerSocket(0)
3 changes: 1 addition & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -121,8 +121,7 @@ private[spark] class Executor(
}

// Start worker thread pool
- val threadPool = new ThreadPoolExecutor(
-   1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable], Utils.daemonThreadFactory)
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")

// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
@@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]

- implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+   Utils.newDaemonCachedThreadPool("Connection manager future execution context"))

private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null

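
As an aside, here is a self-contained sketch of what this ConnectionManager hunk achieves, using only the JDK, Guava, and the Scala standard library (the thread factory is inlined so the snippet runs without Spark on the classpath):

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import com.google.common.util.concurrent.ThreadFactoryBuilder

// The same kind of named daemon factory that Utils.newDaemonCachedThreadPool builds.
val factory = new ThreadFactoryBuilder().setDaemon(true)
  .setNameFormat("Connection manager future execution context-%d").build()

// Futures scheduled on this implicit context run on the named daemon threads.
implicit val futureExecContext =
  ExecutionContext.fromExecutor(Executors.newCachedThreadPool(factory))

val threadName = Future { Thread.currentThread().getName }
// Prints e.g. "Connection manager future execution context-0".
println(Await.result(threadName, 1.second))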
@@ -24,33 +24,16 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
+ import org.apache.spark.util.Utils

/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
- private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
- private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
- private val getTaskResultExecutor = new ThreadPoolExecutor(
-   MIN_THREADS,
-   MAX_THREADS,
-   0L,
-   TimeUnit.SECONDS,
-   new LinkedBlockingDeque[Runnable],
-   new ResultResolverThreadFactory)
-
- class ResultResolverThreadFactory extends ThreadFactory {
-   private var counter = 0
-   private var PREFIX = "Result resolver thread"
-
-   override def newThread(r: Runnable): Thread = {
-     val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
-     counter += 1
-     thread.setDaemon(true)
-     return thread
-   }
- }
+ private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+ private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+   THREADS, "Result resolver thread")

protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
22 changes: 14 additions & 8 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -447,14 +447,17 @@ private[spark] object Utils extends Logging {
hostPortParseResults.get(hostPort)
}

- private[spark] val daemonThreadFactory: ThreadFactory =
-   new ThreadFactoryBuilder().setDaemon(true).build()
+ private val daemonThreadFactoryBuilder: ThreadFactoryBuilder =
+   new ThreadFactoryBuilder().setDaemon(true)

/**
- * Wrapper over newCachedThreadPool.
+ * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor =
-   Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+   val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+   Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }

/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -465,10 +468,13 @@ private[spark] object Utils extends Logging {
}

/**
- * Wrapper over newFixedThreadPool.
+ * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
-   Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+   val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+   Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }

private def listFilesSafely(file: File): Seq[File] = {
val files = file.listFiles()
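
For reference, a runnable sketch (JDK plus Guava only) of the naming mechanism the new helpers rely on; newDaemonFixedThreadPool here mirrors the method added above:

import java.util.concurrent.{Executors, ThreadPoolExecutor}
import com.google.common.util.concurrent.ThreadFactoryBuilder

def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
  // setNameFormat takes a printf-style pattern; Guava substitutes a
  // sequentially assigned integer for "%d", so threads are named
  // "prefix-0", "prefix-1", and so on.
  val threadFactory = new ThreadFactoryBuilder().setDaemon(true)
    .setNameFormat(prefix + "-%d").build()
  Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
}

val pool = newDaemonFixedThreadPool(2, "Result resolver thread")
pool.execute(new Runnable {
  // Prints the name of the first worker thread, "Result resolver thread-0".
  override def run(): Unit = println(Thread.currentThread().getName)
})
pool.shutdown()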
