diff --git a/core/pom.xml b/core/pom.xml index 7c60cf10c3dc2..6d8be37037729 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -150,7 +150,7 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.6 + 3.2.10 colt diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 0e3750fdde415..edc3889c9ae51 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -23,7 +23,10 @@ import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { +private[spark] class HttpFileServer( + securityManager: SecurityManager, + requestedPort: Int = 0) + extends Logging { var baseDir : File = null var fileDir : File = null @@ -38,7 +41,7 @@ private[spark] class HttpFileServer(securityManager: SecurityManager) extends Lo fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir, securityManager) + httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server") httpServer.start() serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 7e9b517f901a2..912558d0cab7d 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -21,7 +21,7 @@ import java.io.File import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.security.authentication.DigestAuthenticator -import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector @@ -41,48 +41,68 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. */ -private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) - extends Logging { +private[spark] class HttpServer( + resourceBase: File, + securityManager: SecurityManager, + requestedPort: Int = 0, + serverName: String = "HTTP server") + extends Logging { + private var server: Server = null - private var port: Int = -1 + private var port: Int = requestedPort def start() { if (server != null) { throw new ServerStateException("Server is already started") } else { logInfo("Starting HTTP Server") - server = new Server() - val connector = new SocketConnector - connector.setMaxIdleTime(60*1000) - connector.setSoLingerTime(-1) - connector.setPort(0) - server.addConnector(connector) - - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val resHandler = new ResourceHandler - resHandler.setResourceBase(resourceBase.getAbsolutePath) - - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - - if (securityManager.isAuthenticationEnabled()) { - logDebug("HttpServer is using security") - val sh = setupSecurityHandler(securityManager) - // make sure we go through security handler to get resources - sh.setHandler(handlerList) - server.setHandler(sh) - } else { - logDebug("HttpServer is not using security") - server.setHandler(handlerList) - } - - server.start() - port = server.getConnectors()(0).getLocalPort() + val (actualServer, actualPort) = + Utils.startServiceOnPort[Server](requestedPort, doStart, serverName) + server = actualServer + port = actualPort } } + /** + * Actually start the HTTP server on the given port. + * + * Note that this is only best effort in the sense that we may end up binding to a nearby port + * in the event of port collision. Return the bound server and the actual port used. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60 * 1000) + connector.setSoLingerTime(-1) + connector.setPort(startPort) + server.addConnector(connector) + + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val resHandler = new ResourceHandler + resHandler.setResourceBase(resourceBase.getAbsolutePath) + + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) + // make sure we go through security handler to get resources + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("HttpServer is not using security") + server.setHandler(handlerList) + } + + server.start() + val actualPort = server.getConnectors()(0).getLocalPort + + (server, actualPort) + } + /** * Setup Jetty to the HashLoginService using a single user with our * shared secret. Configure it to use DIGEST-MD5 authentication so that the password @@ -134,7 +154,7 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan if (server == null) { throw new ServerStateException("Server is not started") } else { - return "http://" + Utils.localIpAddress + ":" + port + "http://" + Utils.localIpAddress + ":" + port } } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index cce7a23d3b9fc..13f0bff7ee507 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -323,6 +323,14 @@ private[spark] object SparkConf { * the scheduler, while the rest of the spark configs can be inherited from the driver later. */ def isExecutorStartupConf(name: String): Boolean = { - isAkkaConf(name) || name.startsWith("spark.akka") || name.startsWith("spark.auth") + isAkkaConf(name) || + name.startsWith("spark.akka") || + name.startsWith("spark.auth") || + isSparkPortConf(name) } + + /** + * Return whether the given config is a Spark port config. + */ + def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port") } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd8e4ac66dc66..9d4edeb6d96cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import scala.collection.JavaConversions._ import scala.collection.mutable -import scala.concurrent.Await import scala.util.Properties import akka.actor._ @@ -151,10 +150,10 @@ object SparkEnv extends Logging { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, securityManager = securityManager) - // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), - // figure out which port number Akka actually bound to and set spark.driver.port to it. - if (isDriver && port == 0) { - conf.set("spark.driver.port", boundPort.toString) + // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // This is so that we tell the executors the correct port to connect to. + if (isDriver) { + conf.set("spark.driver.port", boundPort.toString) } // Create an instance of the class named by the given Java system property, or by @@ -222,7 +221,8 @@ object SparkEnv extends Logging { val httpFileServer = if (isDriver) { - val server = new HttpFileServer(securityManager) + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(securityManager, fileServerPort) server.initialize() conf.set("spark.fileserver.uri", server.serverUri) server diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 487456467b23b..942dc7d7eac87 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -152,7 +152,8 @@ private[broadcast] object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir, securityManager) + val broadcastPort = conf.getInt("spark.broadcast.port", 0) + server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 17c507af2652d..c07003784e8ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -155,8 +155,6 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - // TODO: See if we can initialize akka so return messages are sent back using the same TCP - // flow. Else, this (sadly) requires the DriverClient be routable from the Master. val (actorSystem, _) = AkkaUtils.createActorSystem( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 16aa0493370dd..d86ec1e03e45c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.AkkaUtils */ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) - extends WebUI(master.securityMgr, requestedPort, master.conf) with Logging { + extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { val masterActorRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index a9f531e9e4cae..47fbda600bea7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker +import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.AkkaUtils @@ -34,7 +35,7 @@ class WorkerWebUI( val worker: Worker, val workDir: File, port: Option[Int] = None) - extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf) + extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI") with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index af736de405397..1f46a0f176490 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -115,8 +115,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf + val port = executorConf.getInt("spark.executor.port", 0) val (fetcher, _) = AkkaUtils.createActorSystem( - "driverPropsFetcher", hostname, 0, executorConf, new SecurityManager(executorConf)) + "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) val driver = fetcher.actorSelection(driverUrl) val timeout = AkkaUtils.askTimeout(executorConf) val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) @@ -126,7 +127,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, 0, driverConf, new SecurityManager(driverConf)) + "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 566e8a4aaa1d2..4c00225280cce 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -38,8 +38,12 @@ import scala.language.postfixOps import org.apache.spark._ import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager(port: Int, conf: SparkConf, - securityManager: SecurityManager) extends Logging { +private[spark] class ConnectionManager( + port: Int, + conf: SparkConf, + securityManager: SecurityManager, + name: String = "Connection manager") + extends Logging { class MessageStatus( val message: Message, @@ -105,7 +109,11 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) - serverChannel.socket.bind(new InetSocketAddress(port)) + private def startService(port: Int): (ServerSocketChannel, Int) = { + serverChannel.socket.bind(new InetSocketAddress(port)) + (serverChannel, serverChannel.socket.getLocalPort) + } + Utils.startServiceOnPort[ServerSocketChannel](port, startService, name) serverChannel.register(selector, SelectionKey.OP_ACCEPT) val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index a76a070b5b863..8947e66f4577c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -96,17 +96,23 @@ class JdbcRDD[T: ClassTag]( override def close() { try { - if (null != rs && ! rs.isClosed()) rs.close() + if (null != rs && ! rs.isClosed()) { + rs.close() + } } catch { case e: Exception => logWarning("Exception closing resultset", e) } try { - if (null != stmt && ! stmt.isClosed()) stmt.close() + if (null != stmt && ! stmt.isClosed()) { + stmt.close() + } } catch { case e: Exception => logWarning("Exception closing statement", e) } try { - if (null != conn && ! stmt.isClosed()) conn.close() + if (null != conn && ! conn.isClosed()) { + conn.close() + } logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) @@ -120,3 +126,4 @@ object JdbcRDD { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } } + diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index eb920ab0c0b67..f176d09816f5e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi object TaskLocality extends Enumeration { // Process local is expected to be used ONLY within TaskSetManager for now. - val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value + val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value type TaskLocality = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index d2f764fc22f54..6c0d1b2752a81 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -89,11 +89,11 @@ private[spark] class TaskSchedulerImpl( // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host - private val executorsByHost = new HashMap[String, HashSet[String]] + protected val executorsByHost = new HashMap[String, HashSet[String]] protected val hostsByRack = new HashMap[String, HashSet[String]] - private val executorIdToHost = new HashMap[String, String] + protected val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into var dagScheduler: DAGScheduler = null @@ -249,6 +249,7 @@ private[spark] class TaskSchedulerImpl( // Take each TaskSet in our scheduling order, and then offer it each node in increasing order // of locality levels so that it gets a chance to launch local tasks on all of them. + // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY var launchedTask = false for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { do { @@ -265,7 +266,7 @@ private[spark] class TaskSchedulerImpl( activeExecutorIds += execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK - assert (availableCpus(i) >= 0) + assert(availableCpus(i) >= 0) launchedTask = true } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 8b5e8cb802a45..20a4bd12f93f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,6 +79,7 @@ private[spark] class TaskSetManager( private val numFailures = new Array[Int](numTasks) // key is taskId, value is a Map of executor id to when it failed private val failedExecutors = new HashMap[Int, HashMap[String, Long]]() + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksSuccessful = 0 @@ -179,26 +180,17 @@ private[spark] class TaskSetManager( } } - var hadAliveLocations = false for (loc <- tasks(index).preferredLocations) { for (execId <- loc.executorId) { addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer)) } - if (sched.hasExecutorsAliveOnHost(loc.host)) { - hadAliveLocations = true - } addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - if(sched.hasHostAliveOnRack(rack)){ - hadAliveLocations = true - } } } - if (!hadAliveLocations) { - // Even though the task might've had preferred locations, all of those hosts or executors - // are dead; put it in the no-prefs list so we can schedule it elsewhere right away. + if (tasks(index).preferredLocations == Nil) { addTo(pendingTasksWithNoPrefs) } @@ -239,7 +231,6 @@ private[spark] class TaskSetManager( */ private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = { var indexOffset = list.size - while (indexOffset > 0) { indexOffset -= 1 val index = list(indexOffset) @@ -288,12 +279,12 @@ private[spark] class TaskSetManager( !hasAttemptOnHost(index, host) && !executorIsBlacklisted(execId, index) if (!speculatableTasks.isEmpty) { - // Check for process-local or preference-less tasks; note that tasks can be process-local + // Check for process-local tasks; note that tasks can be process-local // on multiple nodes when we replicate cached blocks, as in Spark Streaming for (index <- speculatableTasks if canRunOnHost(index)) { val prefs = tasks(index).preferredLocations val executors = prefs.flatMap(_.executorId) - if (prefs.size == 0 || executors.contains(execId)) { + if (executors.contains(execId)) { speculatableTasks -= index return Some((index, TaskLocality.PROCESS_LOCAL)) } @@ -310,6 +301,17 @@ private[spark] class TaskSetManager( } } + // Check for no-preference tasks + if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) { + for (index <- speculatableTasks if canRunOnHost(index)) { + val locations = tasks(index).preferredLocations + if (locations.size == 0) { + speculatableTasks -= index + return Some((index, TaskLocality.PROCESS_LOCAL)) + } + } + } + // Check for rack-local tasks if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { for (rack <- sched.getRackForHost(host)) { @@ -341,20 +343,27 @@ private[spark] class TaskSetManager( * * @return An option containing (task index within the task set, locality, is speculative?) */ - private def findTask(execId: String, host: String, locality: TaskLocality.Value) + private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value) : Option[(Int, TaskLocality.Value, Boolean)] = { for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) { return Some((index, TaskLocality.PROCESS_LOCAL, false)) } - if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) { for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) { return Some((index, TaskLocality.NODE_LOCAL, false)) } } - if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) { + // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic + for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { + return Some((index, TaskLocality.PROCESS_LOCAL, false)) + } + } + + if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) { for { rack <- sched.getRackForHost(host) index <- findTaskFromList(execId, getPendingTasksForRack(rack)) @@ -363,25 +372,27 @@ private[spark] class TaskSetManager( } } - // Look for no-pref tasks after rack-local tasks since they can run anywhere. - for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) { - return Some((index, TaskLocality.PROCESS_LOCAL, false)) - } - - if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) { + if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) { for (index <- findTaskFromList(execId, allPendingTasks)) { return Some((index, TaskLocality.ANY, false)) } } - // Finally, if all else has failed, find a speculative task - findSpeculativeTask(execId, host, locality).map { case (taskIndex, allowedLocality) => - (taskIndex, allowedLocality, true) - } + // find a speculative task if all others tasks have been scheduled + findSpeculativeTask(execId, host, maxLocality).map { + case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)} } /** * Respond to an offer of a single executor from the scheduler by finding a task + * + * NOTE: this function is either called with a maxLocality which + * would be adjusted by delay scheduling algorithm or it will be with a special + * NO_PREF locality which will be not modified + * + * @param execId the executor Id of the offered resource + * @param host the host Id of the offered resource + * @param maxLocality the maximum locality we want to schedule the tasks at */ def resourceOffer( execId: String, @@ -392,9 +403,14 @@ private[spark] class TaskSetManager( if (!isZombie) { val curTime = clock.getTime() - var allowedLocality = getAllowedLocalityLevel(curTime) - if (allowedLocality > maxLocality) { - allowedLocality = maxLocality // We're not allowed to search for farther-away tasks + var allowedLocality = maxLocality + + if (maxLocality != TaskLocality.NO_PREF) { + allowedLocality = getAllowedLocalityLevel(curTime) + if (allowedLocality > maxLocality) { + // We're not allowed to search for farther-away tasks + allowedLocality = maxLocality + } } findTask(execId, host, allowedLocality) match { @@ -410,8 +426,11 @@ private[spark] class TaskSetManager( taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(taskLocality) - lastLaunchTime = curTime + // NO_PREF will not affect the variables related to delay scheduling + if (maxLocality != TaskLocality.NO_PREF) { + currentLocalityIndex = getLocalityIndex(taskLocality) + lastLaunchTime = curTime + } // Serialize and return the task val startTime = clock.getTime() // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here @@ -639,8 +658,7 @@ private[spark] class TaskSetManager( override def executorLost(execId: String, host: String) { logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id) - // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a - // task that used to have locations on only this host might now go to the no-prefs list. Note + // Re-enqueue pending tasks for this host based on the status of the cluster. Note // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because findTaskFromList will skip already-running tasks. for (index <- getPendingTasksForExecutor(execId)) { @@ -671,6 +689,9 @@ private[spark] class TaskSetManager( for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } + // recalculate valid locality levels and waits when executor is lost + myLocalityLevels = computeValidLocalityLevels() + localityWaits = myLocalityLevels.map(getLocalityWait) } /** @@ -722,17 +743,17 @@ private[spark] class TaskSetManager( conf.get("spark.locality.wait.node", defaultWait).toLong case TaskLocality.RACK_LOCAL => conf.get("spark.locality.wait.rack", defaultWait).toLong - case TaskLocality.ANY => - 0L + case _ => 0L } } /** * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been * added to queues using addPendingTask. + * */ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { - import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY} + import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { @@ -742,6 +763,9 @@ private[spark] class TaskSetManager( pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } + if (!pendingTasksWithNoPrefs.isEmpty) { + levels += NO_PREF + } if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL @@ -751,20 +775,7 @@ private[spark] class TaskSetManager( levels.toArray } - // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available def executorAdded() { - def newLocAvail(index: Int): Boolean = { - for (loc <- tasks(index).preferredLocations) { - if (sched.hasExecutorsAliveOnHost(loc.host) || - (sched.getRackForHost(loc.host).isDefined && - sched.hasHostAliveOnRack(sched.getRackForHost(loc.host).get))) { - return true - } - } - false - } - logInfo("Re-computing pending task lists.") - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_)) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 9a356d0dbaf17..24db2f287a47b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -40,7 +40,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val ser = Serializer.getSerializer(dep.serializer.orNull) private val conf = SparkEnv.get.conf - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 private var sorter: ExternalSorter[K, V, _] = null private var outputFile: File = null diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c0a06017945f0..3876cf43e2a7d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -60,10 +60,12 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker) extends Logging { + private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) - val connectionManager = new ConnectionManager(0, conf, securityManager) + val connectionManager = + new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") implicit val futureExecContext = connectionManager.futureExecContext diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 28aa35bc7e147..f9fdffae8bd8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -73,7 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val sortBasedShuffle = conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index a2535e3c1c41f..29e9cf947856f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -174,40 +174,32 @@ private[spark] object JettyUtils extends Logging { hostName: String, port: Int, handlers: Seq[ServletContextHandler], - conf: SparkConf): ServerInfo = { + conf: SparkConf, + serverName: String = ""): ServerInfo = { val collection = new ContextHandlerCollection collection.setHandlers(handlers.toArray) addFilters(handlers, conf) - @tailrec + // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(hostName, currentPort)) val pool = new QueuedThreadPool pool.setDaemon(true) server.setThreadPool(pool) server.setHandler(collection) - - Try { + try { server.start() - } match { - case s: Success[_] => - (server, server.getConnectors.head.getLocalPort) - case f: Failure[_] => - val nextPort = (currentPort + 1) % 65536 + (server, server.getConnectors.head.getLocalPort) + } catch { + case e: Exception => server.stop() pool.stop() - val msg = s"Failed to create UI on port $currentPort. Trying again on port $nextPort." - if (f.toString.contains("Address already in use")) { - logWarning(s"$msg - $f") - } else { - logError(msg, f.exception) - } - connect(nextPort) + throw e } } - val (server, boundPort) = connect(port) + val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName) ServerInfo(server, boundPort, collection) } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 097a1b81e1dd1..6c788a37dc70b 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -36,7 +36,7 @@ private[spark] class SparkUI( val listenerBus: SparkListenerBus, var appName: String, val basePath: String = "") - extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath) + extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 856273e1d4e21..5f52f95088007 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -39,7 +39,8 @@ private[spark] abstract class WebUI( securityManager: SecurityManager, port: Int, conf: SparkConf, - basePath: String = "") + basePath: String = "", + name: String = "") extends Logging { protected val tabs = ArrayBuffer[WebUITab]() @@ -97,7 +98,7 @@ private[spark] abstract class WebUI( def bind() { assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) try { - serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf)) + serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf, name)) logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index feafd654e9e71..d6afb73b74242 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{Actor, ActorRef, ActorSystem, ExtendedActorSystem} +import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask import com.typesafe.config.ConfigFactory @@ -44,14 +44,28 @@ private[spark] object AkkaUtils extends Logging { * If indestructible is set to true, the Actor System will continue running in the event * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. */ - def createActorSystem(name: String, host: String, port: Int, - conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { + def createActorSystem( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): (ActorSystem, Int) = { + val startService: Int => (ActorSystem, Int) = { actualPort => + doCreateActorSystem(name, host, actualPort, conf, securityManager) + } + Utils.startServiceOnPort(port, startService, name) + } + + private def doCreateActorSystem( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): (ActorSystem, Int) = { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) - val akkaTimeout = conf.getInt("spark.akka.timeout", 100) - val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 30073a82857d2..c60be4f8a11d2 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection} +import java.net._ import java.nio.ByteBuffer import java.util.{Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -1331,4 +1331,75 @@ private[spark] object Utils extends Logging { .map { case (k, v) => s"-D$k=$v" } } + /** + * Default number of retries in binding to a port. + */ + val portMaxRetries: Int = { + if (sys.props.contains("spark.testing")) { + // Set a higher number of retries for tests... + sys.props.get("spark.ports.maxRetries").map(_.toInt).getOrElse(100) + } else { + Option(SparkEnv.get) + .flatMap(_.conf.getOption("spark.ports.maxRetries")) + .map(_.toInt) + .getOrElse(16) + } + } + + /** + * Attempt to start a service on the given port, or fail after a number of attempts. + * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). + * + * @param startPort The initial port to start the service on. + * @param maxRetries Maximum number of retries to attempt. + * A value of 3 means attempting ports n, n+1, n+2, and n+3, for example. + * @param startService Function to start service on a given port. + * This is expected to throw java.net.BindException on port collision. + */ + def startServiceOnPort[T]( + startPort: Int, + startService: Int => (T, Int), + serviceName: String = "", + maxRetries: Int = portMaxRetries): (T, Int) = { + val serviceString = if (serviceName.isEmpty) "" else s" '$serviceName'" + for (offset <- 0 to maxRetries) { + // Do not increment port if startPort is 0, which is treated as a special port + val tryPort = if (startPort == 0) startPort else (startPort + offset) % 65536 + try { + val (service, port) = startService(tryPort) + logInfo(s"Successfully started service$serviceString on port $port.") + return (service, port) + } catch { + case e: Exception if isBindCollision(e) => + if (offset >= maxRetries) { + val exceptionMessage = + s"${e.getMessage}: Service$serviceString failed after $maxRetries retries!" + val exception = new BindException(exceptionMessage) + // restore original stack trace + exception.setStackTrace(e.getStackTrace) + throw exception + } + logWarning(s"Service$serviceString could not bind on port $tryPort. " + + s"Attempting port ${tryPort + 1}.") + } + } + // Should never happen + throw new SparkException(s"Failed to start service$serviceString on port $startPort") + } + + /** + * Return whether the exception is caused by an address-port collision when binding. + */ + def isBindCollision(exception: Throwable): Boolean = { + exception match { + case e: BindException => + if (e.getMessage != null && e.getMessage.contains("Address already in use")) { + return true + } + isBindCollision(e.getCause) + case e: Exception => isBindCollision(e.getCause) + case _ => false + } + } + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cc0423856cefb..260a5c3888aa7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -101,7 +101,7 @@ class ExternalAppendOnlyMap[K, V, C]( private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 101c83b264f63..3f93afd57b3ad 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -84,7 +84,7 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 // Size of object batches when reading/writing from serializers. // diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index c52368b5514db..ffd23380a886f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -85,14 +85,31 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val finishedManagers = new ArrayBuffer[TaskSetManager] val taskSetsFailed = new ArrayBuffer[String] - val executors = new mutable.HashMap[String, String] ++ liveExecutors + val executors = new mutable.HashMap[String, String] + for ((execId, host) <- liveExecutors) { + addExecutor(execId, host) + } + for ((execId, host) <- liveExecutors; rack <- getRackForHost(host)) { hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host } dagScheduler = new FakeDAGScheduler(sc, this) - def removeExecutor(execId: String): Unit = executors -= execId + def removeExecutor(execId: String) { + executors -= execId + val host = executorIdToHost.get(execId) + assert(host != None) + val hostId = host.get + val executorsOnHost = executorsByHost(hostId) + executorsOnHost -= execId + for (rack <- getRackForHost(hostId); hosts <- hostsByRack.get(rack)) { + hosts -= hostId + if (hosts.isEmpty) { + hostsByRack -= rack + } + } + } override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager @@ -100,8 +117,15 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) + override def hasHostAliveOnRack(rack: String): Boolean = { + hostsByRack.get(rack) != None + } + def addExecutor(execId: String, host: String) { executors.put(execId, host) + val executorsOnHost = executorsByHost.getOrElseUpdate(host, new mutable.HashSet[String]) + executorsOnHost += execId + executorIdToHost += execId -> host for (rack <- getRackForHost(host)) { hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host } @@ -123,7 +147,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { } class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { - import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf @@ -134,18 +158,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // Offer a host with process-local as the constraint; this should work because the TaskSet - // above won't have any locality preferences - val taskOption = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + // Offer a host with NO_PREF as the constraint, + // we should get a nopref task immediately since that's what we only have + var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) - val task = taskOption.get - assert(task.executorId === "exec1") - assert(sched.startedTasks.contains(0)) - - // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) // Tell it the task has finished manager.handleSuccessfulTask(0, createTaskResult(0)) @@ -161,7 +180,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First three offers should all find tasks for (i <- 0 until 3) { - val taskOption = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) + var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) val task = taskOption.get assert(task.executorId === "exec1") @@ -169,7 +188,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.startedTasks.toSet === Set(0, 1, 2)) // Re-offer the host -- now we should get no more tasks - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) + assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None) // Finish the first two tasks manager.handleSuccessfulTask(0, createTaskResult(0)) @@ -211,37 +230,40 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { ) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - - // Offer host1, exec1 again: the last task, which has no prefs, should be chosen - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 3) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) clock.advance(LOCALITY_WAIT) - - // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) === None) - - // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + // Offer host1, exec1 again, at NODE_LOCAL level: the node local (task 2) should + // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) - // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL) === None) - - // Offer host1, exec1 again, at ANY level: nothing should get chosen - assert(manager.resourceOffer("exec1", "host1", ANY) === None) + // Offer host2, exec3 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index == 1) + // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task + // after failing to find a node_Local task + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) clock.advance(LOCALITY_WAIT) + assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) + } - // Offer host1, exec1 again, at ANY level: task 1 should get chosen - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) - - // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks - assert(manager.resourceOffer("exec1", "host1", ANY) === None) + test("we do not need to delay scheduling when we only have noPref tasks in the queue") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec3", "host2")) + val taskSet = FakeTask.createTaskSet(3, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec3")), + Seq() // Last task has no locality prefs + ) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) + assert(manager.resourceOffer("exec3", "host2", PROCESS_LOCAL).get.index === 1) + assert(manager.resourceOffer("exec3", "host2", NODE_LOCAL) == None) + assert(manager.resourceOffer("exec3", "host2", NO_PREF).get.index === 2) } test("delay scheduling with fallback") { @@ -298,20 +320,24 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - // Offer host1 again: third task should be chosen immediately because host3 is not up - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 2) - - // After this, nothing should get chosen + // After this, nothing should get chosen, because we have separated tasks with unavailable preference + // from the noPrefPendingTasks assert(manager.resourceOffer("exec1", "host1", ANY) === None) // Now mark host2 as dead sched.removeExecutor("exec2") manager.executorLost("exec2", "host2") - // Task 1 should immediately be launched on host1 because its original host is gone + // nothing should be chosen + assert(manager.resourceOffer("exec1", "host1", ANY) === None) + + clock.advance(LOCALITY_WAIT * 2) + + // task 1 and 2 would be scheduled as nonLocal task assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 1) + assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 2) - // Now that all tasks have launched, nothing new should be launched anywhere else + // all finished assert(manager.resourceOffer("exec1", "host1", ANY) === None) assert(manager.resourceOffer("exec2", "host2", ANY) === None) } @@ -373,7 +399,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val manager = new TaskSetManager(sched, taskSet, 4, clock) { - val offerResult = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) @@ -384,15 +410,15 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec1 fails after failure 1 due to blacklist - assert(manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.NODE_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.RACK_LOCAL).isEmpty) - assert(manager.resourceOffer("exec1", "host1", TaskLocality.ANY).isEmpty) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", RACK_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", ANY).isEmpty) } // Run the task on exec1.1 - should work, and then fail it on exec1.1 { - val offerResult = manager.resourceOffer("exec1.1", "host1", TaskLocality.NODE_LOCAL) + val offerResult = manager.resourceOffer("exec1.1", "host1", NODE_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task for exec1.1, offerResult = " + offerResult) @@ -404,12 +430,12 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec1.1 fails after failure 2 due to blacklist - assert(manager.resourceOffer("exec1.1", "host1", TaskLocality.NODE_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1.1", "host1", NODE_LOCAL).isEmpty) } // Run the task on exec2 - should work, and then fail it on exec2 { - val offerResult = manager.resourceOffer("exec2", "host2", TaskLocality.ANY) + val offerResult = manager.resourceOffer("exec2", "host2", ANY) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) @@ -420,20 +446,20 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(!sched.taskSetsFailed.contains(taskSet.id)) // Ensure scheduling on exec2 fails after failure 3 due to blacklist - assert(manager.resourceOffer("exec2", "host2", TaskLocality.ANY).isEmpty) + assert(manager.resourceOffer("exec2", "host2", ANY).isEmpty) } // After reschedule delay, scheduling on exec1 should be possible. clock.advance(rescheduleDelay) { - val offerResult = manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL) + val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) assert(offerResult.isDefined, "Expect resource offer to return a task") assert(offerResult.get.index === 0) assert(offerResult.get.executorId === "exec1") - assert(manager.resourceOffer("exec1", "host1", TaskLocality.PROCESS_LOCAL).isEmpty) + assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).isEmpty) // Cause exec1 to fail : failure 4 manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) @@ -443,7 +469,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(sched.taskSetsFailed.contains(taskSet.id)) } - test("new executors get added") { + test("new executors get added and lost") { // Assign host2 to rack2 FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") @@ -456,26 +482,25 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { Seq()) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - // All tasks added to no-pref list since no preferred location is available - assert(manager.pendingTasksWithNoPrefs.size === 4) // Only ANY is valid - assert(manager.myLocalityLevels.sameElements(Array(ANY))) + assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) // Add a new executor sched.addExecutor("execD", "host1") manager.executorAdded() - // Task 0 and 1 should be removed from no-pref list - assert(manager.pendingTasksWithNoPrefs.size === 2) // Valid locality should contain NODE_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY))) + assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) // Add another executor sched.addExecutor("execC", "host2") manager.executorAdded() - // No-pref list now only contains task 3 - assert(manager.pendingTasksWithNoPrefs.size === 1) // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements( - Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) - FakeRackUtil.cleanUp() + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY))) + // test if the valid locality is recomputed when the executor is lost + sched.removeExecutor("execC") + manager.executorLost("execC", "host2") + assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY))) + sched.removeExecutor("execD") + manager.executorLost("execD", "host1") + assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) } test("test RACK_LOCAL tasks") { @@ -506,7 +531,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { // Offer host2 // Task 1 can be scheduled with RACK_LOCAL assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) - FakeRackUtil.cleanUp() } test("do not emit warning when serialized task is small") { @@ -536,6 +560,53 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.emittedTaskSizeWarning) } + test("speculative and noPref task should be scheduled after node-local") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2"), TaskLocation("host1")), + Seq(), + Seq(TaskLocation("host3", "execC"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index == 1) + + manager.speculatableTasks += 1 + clock.advance(LOCALITY_WAIT) + // schedule the nonPref task + assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2) + // schedule the speculative task + assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1) + clock.advance(LOCALITY_WAIT * 3) + // schedule non-local tasks + assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3) + } + + test("node-local tasks should be scheduled right away when there are only node-local and no-preference tasks") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(4, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(), + Seq(TaskLocation("host3"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + // node-local tasks are scheduled without delay + assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) + assert(manager.resourceOffer("execA", "host2", NODE_LOCAL).get.index === 1) + assert(manager.resourceOffer("execA", "host3", NODE_LOCAL).get.index === 3) + assert(manager.resourceOffer("execA", "host3", NODE_LOCAL) === None) + + // schedule no-preference after node local ones + assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) + } + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ee936bc78f49..70d423ba8a04d 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import scala.util.Random import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} -import java.net.URI +import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import com.google.common.base.Charsets @@ -265,4 +265,36 @@ class UtilsSuite extends FunSuite { Array("hdfs:/a.jar", "s3:/another.jar")) } + test("isBindCollision") { + // Negatives + assert(!Utils.isBindCollision(null)) + assert(!Utils.isBindCollision(new Exception)) + assert(!Utils.isBindCollision(new Exception(new Exception))) + assert(!Utils.isBindCollision(new Exception(new BindException))) + assert(!Utils.isBindCollision(new Exception(new BindException("Random message")))) + + // Positives + val be = new BindException("Address already in use") + val be1 = new Exception(new BindException("Address already in use")) + val be2 = new Exception(new Exception(new BindException("Address already in use"))) + assert(Utils.isBindCollision(be)) + assert(Utils.isBindCollision(be1)) + assert(Utils.isBindCollision(be2)) + + // Actual bind exception + var server1: ServerSocket = null + var server2: ServerSocket = null + try { + server1 = new java.net.ServerSocket(0) + server2 = new java.net.ServerSocket(server1.getLocalPort) + } catch { + case e: Exception => + assert(e.isInstanceOf[java.net.BindException]) + assert(Utils.isBindCollision(e)) + } finally { + Option(server1).foreach(_.close()) + Option(server2).foreach(_.close()) + } + } + } diff --git a/docs/configuration.md b/docs/configuration.md index 5e7556c08ee36..5e3eb0f0871af 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -266,7 +266,7 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.file.buffer.kb - 100 + 32 Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers reduce the number of disk seeks and system calls made in creating intermediate shuffle files. @@ -566,6 +566,7 @@ Apart from these, the following properties are also available, and may be useful (local hostname) Hostname or IP address for the driver to listen on. + This is used for communicating with the executors and the standalone Master. @@ -573,6 +574,51 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver to listen on. + This is used for communicating with the executors and the standalone Master. + + + + spark.fileserver.port + (random) + + Port for the driver's HTTP file server to listen on. + + + + spark.broadcast.port + (random) + + Port for the driver's HTTP broadcast server to listen on. + This is not relevant for torrent broadcast. + + + + spark.replClassServer.port + (random) + + Port for the driver's HTTP class server to listen on. + This is only relevant for the Spark shell. + + + + spark.blockManager.port + (random) + + Port for all block managers to listen on. These exist on both the driver and the executors. + + + + spark.executor.port + (random) + + Port for the executor to listen on. This is used for communicating with the driver. + + + + spark.port.maxRetries + 16 + + Maximum number of retries when binding to a port before giving up. diff --git a/docs/security.md b/docs/security.md index 8312f8d017e1f..ec0523184d665 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,6 +7,9 @@ Spark currently supports authentication via a shared secret. Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. +* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* + +## Web UI The Spark UI can also be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. @@ -14,10 +17,132 @@ Spark also supports modify ACLs to control who has access to modify a running Sp Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. +## Event Logging + If your applications are using event logging, the directory where the event logs go (`spark.eventLog.dir`) should be manually created and have the proper permissions set on it. If you want those log files secured, the permissions should be set to `drwxrwxrwxt` for that directory. The owner of the directory should be the super user who is running the history server and the group permissions should be restricted to super user group. This will allow all users to write to the directory but will prevent unprivileged users from removing or renaming a file unless they own the file or directory. The event log files will be created by Spark with permissions such that only the user and group have read and write access. -**IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* +## Configuring Ports for Network Security + +Spark makes heavy use of the network, and some environments have strict requirements for using tight +firewall settings. Below are the primary ports that Spark uses for its communication and how to +configure those ports. + +### Standalone mode only + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FromToDefault PortPurposeConfiguration + SettingNotes
BrowserStandalone Master8080Web UIspark.master.ui.port /
SPARK_MASTER_WEBUI_PORT
Jetty-based. Standalone mode only.
BrowserStandalone Worker8081Web UIspark.worker.ui.port /
SPARK_WORKER_WEBUI_PORT
Jetty-based. Standalone mode only.
Driver /
Standalone Worker
Standalone Master7077Submit job to cluster /
Join cluster
SPARK_MASTER_PORTAkka-based. Set to "0" to choose a port randomly. Standalone mode only.
Standalone MasterStandalone Worker(random)Schedule executorsSPARK_WORKER_PORTAkka-based. Set to "0" to choose a port randomly. Standalone mode only.
+ +### All cluster managers + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
FromToDefault PortPurposeConfiguration + SettingNotes
BrowserApplication4040Web UIspark.ui.portJetty-based
BrowserHistory Server18080Web UIspark.history.ui.portJetty-based
Executor /
Standalone Master
Driver(random)Connect to application /
Notify executor state changes
spark.driver.portAkka-based. Set to "0" to choose a port randomly.
DriverExecutor(random)Schedule tasksspark.executor.portAkka-based. Set to "0" to choose a port randomly.
ExecutorDriver(random)File server for files and jarsspark.fileserver.portJetty-based
ExecutorDriver(random)HTTP Broadcastspark.broadcast.portJetty-based. Not used by TorrentBroadcast, which sends data through the block manager + instead.
ExecutorDriver(random)Class file serverspark.replClassServer.portJetty-based. Only used in Spark shells.
Executor / DriverExecutor / Driver(random)Block Manager portspark.blockManager.portRaw socket via ServerSocketChannel
-See the [configuration page](configuration.html) for more details on the security configuration parameters. -See org.apache.spark.SecurityManager for implementation details about security. +See the [configuration page](configuration.html) for more details on the security configuration +parameters, and +org.apache.spark.SecurityManager for implementation details about security. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 293a7ac9bc9aa..c791c81f8bfd0 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -299,97 +299,15 @@ You can run Spark alongside your existing Hadoop cluster by just launching it as # Configuring Ports for Network Security -Spark makes heavy use of the network, and some environments have strict requirements for using tight -firewall settings. Below are the primary ports that Spark uses for its communication and how to -configure those ports. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FromToDefault PortPurposeConfiguration - SettingNotes
BrowserStandalone Cluster Master8080Web UIspark.master.ui.portJetty-based
BrowserDriver4040Web UIspark.ui.portJetty-based
BrowserHistory Server18080Web UIspark.history.ui.portJetty-based
BrowserWorker8081Web UIspark.worker.ui.portJetty-based
ApplicationStandalone Cluster Master7077Submit job to clusterspark.driver.portAkka-based. Set to "0" to choose a port randomly
WorkerStandalone Cluster Master7077Join clusterspark.driver.portAkka-based. Set to "0" to choose a port randomly
ApplicationWorker(random)Join clusterSPARK_WORKER_PORT (standalone cluster)Akka-based
Driver and other WorkersWorker(random) -
    -
  • File server for file and jars
  • -
  • Http Broadcast
  • -
  • Class file server (Spark Shell only)
  • -
-
NoneJetty-based. Each of these services starts on a random port that cannot be configured
+Spark makes heavy use of the network, and some environments have strict requirements for using +tight firewall settings. For a complete list of ports to configure, see the +[security page](security.html#configuring-ports-for-network-security). # High Availability By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below. -## Standby Masters with ZooKeeper +# Standby Masters with ZooKeeper **Overview** @@ -429,7 +347,7 @@ There's an important distinction to be made between "registering with a Master" Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of. -## Single-Node Recovery with Local File System +# Single-Node Recovery with Local File System **Overview** diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 2aee99949223a..4e2275ab238f7 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -68,6 +68,10 @@ org.slf4j slf4j-simple + + org.apache.zookeeper + zookeeper +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1d5d3762ed8e9..fd0b9556c7d54 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable { .setNumIterations(numIterations) .setRegParam(regParam) .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) if (regType == "l2") { lrAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { @@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val SVMAlg = new SVMWithSGD() + SVMAlg.setIntercept(intercept) + SVMAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + SVMAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + SVMAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - SVMWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + SVMAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeightsBA: Array[Byte], + regParam: Double, + regType: String, + intercept: Boolean): java.util.List[java.lang.Object] = { + val LogRegAlg = new LogisticRegressionWithSGD() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) + if (regType == "l2") { + LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) + } else if (regType == "l1") { + LogRegAlg.optimizer.setUpdater(new L1Updater) + } else if (regType != "none") { + throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: [l1, l2, none].") + } trainRegressionModel( (data, initialWeights) => - LogisticRegressionWithSGD.train( - data, - numIterations, - stepSize, - miniBatchFraction, - initialWeights), + LogRegAlg.run(data, initialWeights), dataBytesJRDD, initialWeightsBA) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index db425d866bbad..fce8fe29f6e40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -52,13 +52,13 @@ class KMeans private ( def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) /** Set the number of clusters to create (k). Default: 2. */ - def setK(k: Int): KMeans = { + def setK(k: Int): this.type = { this.k = k this } /** Set maximum number of iterations to run. Default: 20. */ - def setMaxIterations(maxIterations: Int): KMeans = { + def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } @@ -68,7 +68,7 @@ class KMeans private ( * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. */ - def setInitializationMode(initializationMode: String): KMeans = { + def setInitializationMode(initializationMode: String): this.type = { if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) { throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode) } @@ -83,7 +83,7 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Experimental - def setRuns(runs: Int): KMeans = { + def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") } @@ -95,7 +95,7 @@ class KMeans private ( * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. */ - def setInitializationSteps(initializationSteps: Int): KMeans = { + def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") } @@ -107,7 +107,7 @@ class KMeans private ( * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. */ - def setEpsilon(epsilon: Double): KMeans = { + def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 87c81e7b0bd2f..3bf44ad7c44e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.{HashPartitioner, Logging} + +import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Entry in vocabulary @@ -58,29 +59,63 @@ private case class VocabWord( * Efficient Estimation of Word Representations in Vector Space * and * Distributed Representations of Words and Phrases and their Compositionality. - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations to run, should be smaller than or equal to parallelism */ @Experimental -class Word2Vec( - val size: Int, - val startingAlpha: Double, - val parallelism: Int, - val numIterations: Int) extends Serializable with Logging { +class Word2Vec extends Serializable with Logging { + + private var vectorSize = 100 + private var startingAlpha = 0.025 + private var numPartitions = 1 + private var numIterations = 1 + private var seed = Utils.random.nextLong() + + /** + * Sets vector size (default: 100). + */ + def setVectorSize(vectorSize: Int): this.type = { + this.vectorSize = vectorSize + this + } + + /** + * Sets initial learning rate (default: 0.025). + */ + def setLearningRate(learningRate: Double): this.type = { + this.startingAlpha = learningRate + this + } /** - * Word2Vec with a single thread. + * Sets number of partitions (default: 1). Use a small number for accuracy. */ - def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1) + def setNumPartitions(numPartitions: Int): this.type = { + require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") + this.numPartitions = numPartitions + this + } + + /** + * Sets number of iterations (default: 1), which should be smaller than or equal to number of + * partitions. + */ + def setNumIterations(numIterations: Int): this.type = { + this.numIterations = numIterations + this + } + + /** + * Sets random seed (default: a random long integer). + */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = size - private val modelPartitionNum = 100 + private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -94,12 +129,12 @@ class Word2Vec( private var vocabHash = mutable.HashMap.empty[String, Int] private var alpha = startingAlpha - private def learnVocab(words:RDD[String]): Unit = { + private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) .map(x => VocabWord( - x._1, - x._2, + x._1, + x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) @@ -245,23 +280,24 @@ class Word2Vec( } } - val newSentences = sentences.repartition(parallelism).cache() + val newSentences = sentences.repartition(numPartitions).cache() + val initRandom = new XORShiftRandom(seed) var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size) + Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) var syn1Global = new Array[Float](vocabSize * layer1Size) - - for(iter <- 1 to numIterations) { - val (aggSyn0, aggSyn1, _, _) = - // TODO: broadcast temp instead of serializing it directly - // or initialize the model in each executor - newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))( - seqOp = (c, v) => (c, v) match { + + for (k <- 1 to numIterations) { + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => + val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount - var wc = wordCount + var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1)) + // TODO: discount by iteration? + alpha = + startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } @@ -269,8 +305,7 @@ class Word2Vec( var pos = 0 while (pos < sentence.size) { val word = sentence(pos) - // TODO: fix random seed - val b = Random.nextInt(window) + val b = random.nextInt(window) // Train Skip-gram var a = b while (a < window * 2 + 1 - b) { @@ -280,7 +315,7 @@ class Word2Vec( val lastWord = sentence(c) val l1 = lastWord * layer1Size val neu1e = new Array[Float](layer1Size) - // Hierarchical softmax + // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { val l2 = bcVocab.value(word).point(d) * layer1Size @@ -303,44 +338,44 @@ class Word2Vec( pos += 1 } (syn0, syn1, lwc, wc) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) - }) + } + Iterator(model) + } + val (aggSyn0, aggSyn1, _, _) = + partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => + val n = syn0_1.length + val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) + val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) + blas.sscal(n, weight1, syn0_1, 1) + blas.sscal(n, weight1, syn1_1, 1) + blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) + blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) + (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + } syn0Global = aggSyn0 syn1Global = aggSyn1 } newSentences.unpersist() - val wordMap = new Array[(String, Array[Float])](vocabSize) + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word val vector = new Array[Float](layer1Size) Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) - wordMap(i) = (word, vector) + word2VecMap += word -> vector i += 1 } - val modelRDD = sc.parallelize(wordMap, modelPartitionNum) - .partitionBy(new HashPartitioner(modelPartitionNum)) - .persist(StorageLevel.MEMORY_AND_DISK) - - new Word2VecModel(modelRDD) + + new Word2VecModel(word2VecMap.toMap) } } /** * Word2Vec model -*/ -class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable { + */ +class Word2VecModel private[mllib] ( + private val model: Map[String, Array[Float]]) extends Serializable { private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri * @return vector representation of word */ def transform(word: String): Vector = { - val result = model.lookup(word) - if (result.isEmpty) { - throw new IllegalStateException(s"$word not in vocabulary") + model.get(word) match { + case Some(vec) => + Vectors.dense(vec.map(_.toDouble)) + case None => + throw new IllegalStateException(s"$word not in vocabulary") } - else Vectors.dense(result(0).map(_.toDouble)) } /** @@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - val topK = model.map { case(w, vec) => - (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) } - .sortByKey(ascending = false) - .take(num + 1) - .map(_.swap) - .tail - - topK - } -} - -object Word2Vec{ - /** - * Train Word2Vec model - * @param input RDD of words - * @param size vector dimension - * @param startingAlpha initial learning rate - * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy) - * @param numIterations number of iterations, should be smaller than or equal to parallelism - * @return Word2Vec model - */ - def train[S <: Iterable[String]]( - input: RDD[S], - size: Int, - startingAlpha: Double, - parallelism: Int = 1, - numIterations:Int = 1): Word2VecModel = { - new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input) + // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) + model.mapValues(vec => cosineSimilarity(fVector, vec)) + .toSeq + .sortBy(- _._2) + .take(num + 1) + .tail + .toArray } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b5db39b68a223..e34335d89eb75 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext { val localDoc = Seq(sentence, sentence) val doc = sc.parallelize(localDoc) .map(line => line.split(" ").toSeq) - val size = 10 - val startingAlpha = 0.025 - val window = 2 - val minCount = 2 - val num = 2 - - val model = Word2Vec.train(doc, size, startingAlpha) + val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) val syms = model.findSynonyms("a", 2) - assert(syms.length == num) + assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") } - test("Word2VecModel") { val num = 2 - val localModel = Seq( + val word2VecMap = Map( ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) ) - val model = new Word2VecModel(sc.parallelize(localModel, 2)) + val model = new Word2VecModel(word2VecMap) val syms = model.findSynonyms("china", num) assert(syms.length == num) assert(syms(0)._1 == "taiwan") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index aac621fe53938..40b588512ff08 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -330,6 +330,8 @@ object TestSettings { fork := true, javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dspark.ports.maxRetries=100", + javaOptions in Test += "-Dspark.ui.port=0", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index a85abbcd02c79..ffdda7ee19302 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -76,11 +76,36 @@ def predict(self, x): class LogisticRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None): - """Train a logistic regression model on the given data.""" + def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, + initialWeights=None, regParam=1.0, regType=None, intercept=False): + """ + Train a logistic regression model on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regParam: The regularizer parameter (default: 1.0). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - d._jrdd, iterations, step, miniBatchFraction, i) + d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data, initialWeights) @@ -121,11 +146,35 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None): - """Train a support vector machine on the given data.""" + miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): + """ + Train a support vector machine on the given data. + + @param data: The training data. + @param iterations: The number of iterations (default: 100). + @param step: The step parameter used in SGD + (default: 1.0). + @param regParam: The regularizer parameter (default: 1.0). + @param miniBatchFraction: Fraction of data to be used for each SGD + iteration. + @param initialWeights: The initial weights (default: None). + @param regType: The type of regularizer used for training + our model. + Allowed values: "l1" for using L1Updater, + "l2" for using + SquaredL2Updater, + "none" for no regularizer. + (default: "none") + @param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + """ sc = data.context + if regType is None: + regType = "none" train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - d._jrdd, iterations, step, regParam, miniBatchFraction, i) + d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 293af6183e9cf..cc72c7ca17bb8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -690,12 +690,12 @@ def _infer_schema_type(obj, dataType): ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (int, long), + LongType: (long,), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), - TimestampType: (datetime.datetime, datetime.time, datetime.date), + TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), StructType: (tuple, list), @@ -1063,12 +1063,15 @@ def applySchema(self, rdd, schema): [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import datetime - >>> rdd = sc.parallelize([(127, -32768, 1.0, + >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ - ... StructField("byte", ByteType(), False), - ... StructField("short", ShortType(), False), + ... StructField("byte1", ByteType(), False), + ... StructField("byte2", ByteType(), False), + ... StructField("short1", ShortType(), False), + ... StructField("short2", ShortType(), False), + ... StructField("int", IntegerType(), False), ... StructField("float", FloatType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", @@ -1077,11 +1080,19 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema).map( - ... lambda x: (x.byte, x.short, x.float, x.time, + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> results = srdd.map( + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, ... x.map["a"], x.struct.b, x.list, x.null)) - >>> srdd.collect()[0] - (127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + >>> results.collect()[0] + (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + + >>> srdd.registerTempTable("table2") + >>> sqlCtx.sql( + ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + + ... "float + 1.1 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index f60bbb4662af1..84b57cd2dc1af 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -102,7 +102,8 @@ import org.apache.spark.util.Utils val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles /** Jetty server that will serve our classes to worker nodes */ - val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + val classServerPort = conf.getInt("spark.replClassServer.port", 0) + val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2ba68cab115fb..c18d7858f0a43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool Batch("Resolution", fixedPoint, ResolveReferences :: ResolveRelations :: + ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: StarExpansion :: @@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = q.resolve(name).getOrElse(u) + val result = q.resolveChildren(name).getOrElse(u) logDebug(s"Resolving $u to $result") result } } } + /** + * In many dialects of SQL is it valid to sort by attributes that are not present in the SELECT + * clause. This rule detects such queries and adds the required attributes to the original + * projection, so that they will be available during sorting. Another projection is added to + * remove these attributes after sorting. + */ + object ResolveSortReferences extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val resolved = unresolved.flatMap(child.resolveChildren) + val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet + + val missingInProject = requiredAttributes -- p.output + if (missingInProject.nonEmpty) { + // Add missing attributes and then project them away after the sort. + Project(projectList, + Sort(ordering, + Project(projectList ++ missingInProject, child))) + } else { + s // Nothing we can do here. Return original plan. + } + case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved => + val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + // A small hack to create an object that will allow us to resolve any references that + // refer to named expressions that are present in the grouping expressions. + val groupingRelation = LocalRelation( + grouping.collect { case ne: NamedExpression => ne.toAttribute } + ) + + logDebug(s"Grouping expressions: $groupingRelation") + val resolved = unresolved.flatMap(groupingRelation.resolve).toSet + val missingInAggs = resolved -- a.outputSet + logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs") + if (missingInAggs.nonEmpty) { + // Add missing grouping exprs and then project them away after the sort. + Project(a.output, + Sort(ordering, + Aggregate(grouping, aggs ++ missingInAggs, child))) + } else { + s // Nothing we can do here. Return original plan. + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 888cb08e95f06..278569f0cb14a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { def childrenResolved: Boolean = !children.exists(!_.resolved) /** - * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as + * Optionally resolves the given string to a [[NamedExpression]] using the input from all child + * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ - def resolve(name: String): Option[NamedExpression] = { + def resolveChildren(name: String): Option[NamedExpression] = + resolve(name, children.flatMap(_.output)) + + /** + * Optionally resolves the given string to a [[NamedExpression]] based on the output of this + * LogicalPlan. The attribute is expressed as string in the following form: + * `[scope].AttributeName.[nested].[fields]...`. + */ + def resolve(name: String): Option[NamedExpression] = + resolve(name, output) + + /** Performs attribute resolution given a name and a sequence of possible attributes. */ + protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { val parts = name.split("\\.") // Collect all attributes that are output by this nodes children where either the first part // matches the name or where the first part matches the scope and the second part matches the // name. Return these matches along with any remaining parts, which represent dotted access to // struct fields. - val options = children.flatMap(_.output).flatMap { option => + val options = input.flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. val remainingParts = if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts @@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { } options.distinct match { - case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it. + case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. // One match, but we also need to extract the requested nested field. - case (a, nestedFields) :: Nil => + case Seq((a, nestedFields)) => a.dataType match { case StructType(fields) => Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) case _ => None // Don't know how to resolve these field references } - case Nil => None // No matches. + case Seq() => None // No matches. case ambiguousReferences => throw new TreeNodeException( this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 40bfd55e95a12..0fd7aaaa36eb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql +import scala.collection.immutable +import scala.collection.JavaConversions._ + import java.util.Properties -import scala.collection.JavaConverters._ -object SQLConf { +private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" - val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" - val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" @@ -66,13 +66,13 @@ trait SQLConf { * Note that the choice of dialect does not affect things like what tables are available or * how query execution is performed. */ - private[spark] def dialect: String = get(DIALECT, "sql") + private[spark] def dialect: String = getConf(DIALECT, "sql") /** When true tables cached using the in-memory columnar caching will be compressed. */ - private[spark] def useCompression: Boolean = get(COMPRESS_CACHED, "false").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt + private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode @@ -84,7 +84,7 @@ trait SQLConf { * Defaults to false as this feature is currently experimental. */ private[spark] def codegenEnabled: Boolean = - if (get(CODEGEN_ENABLED, "false") == "true") true else false + if (getConf(CODEGEN_ENABLED, "false") == "true") true else false /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -94,7 +94,7 @@ trait SQLConf { * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is also 10000. */ private[spark] def autoBroadcastJoinThreshold: Int = - get(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt + getConf(AUTO_BROADCASTJOIN_THRESHOLD, "10000").toInt /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, @@ -102,41 +102,40 @@ trait SQLConf { * properly implemented estimation of this statistic will not be incorrectly broadcasted in joins. */ private[spark] def defaultSizeInBytes: Long = - getOption(DEFAULT_SIZE_IN_BYTES).map(_.toLong).getOrElse(autoBroadcastJoinThreshold + 1) + getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong /** ********************** SQLConf functionality methods ************ */ - def set(props: Properties): Unit = { - settings.synchronized { - props.asScala.foreach { case (k, v) => settings.put(k, v) } - } + /** Set Spark SQL configuration properties. */ + def setConf(props: Properties): Unit = settings.synchronized { + props.foreach { case (k, v) => settings.put(k, v) } } - def set(key: String, value: String): Unit = { + /** Set the given Spark SQL configuration property. */ + def setConf(key: String, value: String): Unit = { require(key != null, "key cannot be null") require(value != null, s"value cannot be null for key: $key") settings.put(key, value) } - def get(key: String): String = { + /** Return the value of Spark SQL configuration property for the given key. */ + def getConf(key: String): String = { Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key)) } - def get(key: String, defaultValue: String): String = { + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. + */ + def getConf(key: String, defaultValue: String): String = { Option(settings.get(key)).getOrElse(defaultValue) } - def getAll: Array[(String, String)] = settings.synchronized { settings.asScala.toArray } - - def getOption(key: String): Option[String] = Option(settings.get(key)) - - def contains(key: String): Boolean = settings.containsKey(key) - - def toDebugString: String = { - settings.synchronized { - settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n") - } - } + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. + */ + def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } private[spark] def clear() { settings.clear() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ecd5fbaa0b094..71d338d21d0f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext) new java.sql.Timestamp(c.getTime().getTime()) case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + case (c: Long, IntegerType) => c.toInt case (c: Double, FloatType) => c.toFloat case (c, StringType) if !c.isInstanceOf[String] => c.toString diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c416a745739b3..7e7bb2859bbcd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -118,7 +118,7 @@ private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) private[sql] object ColumnBuilder { - val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104 + val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = { if (orig.remaining >= size) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index d008806eedbe1..f631ee76fcd78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -36,9 +36,9 @@ import org.apache.spark.sql.Row * }}} */ private[sql] trait NullableColumnBuilder extends ColumnBuilder { - private var nulls: ByteBuffer = _ + protected var nulls: ByteBuffer = _ + protected var nullCount: Int = _ private var pos: Int = _ - private var nullCount: Int = _ abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { nulls = ByteBuffer.allocate(1024) @@ -78,4 +78,9 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { buffer.rewind() buffer } + + protected def buildNonNulls(): ByteBuffer = { + nulls.limit(nulls.position()).rewind() + super.build() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 6ad12a0dcb64d..a5826bb033e41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -46,8 +46,6 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] this: NativeColumnBuilder[T] with WithCompressionSchemes => - import CompressionScheme._ - var compressionEncoders: Seq[Encoder[T]] = _ abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { @@ -81,28 +79,32 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] } } - abstract override def build() = { - val rawBuffer = super.build() + override def build() = { + val nonNullBuffer = buildNonNulls() + val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) if (isWorthCompressing(candidate)) candidate else PassThrough.encoder } - val headerSize = columnHeaderSize(rawBuffer) + // Header = column type ID + null count + null positions + val headerSize = 4 + 4 + nulls.limit() val compressedSize = if (encoder.compressedSize == 0) { - rawBuffer.limit - headerSize + nonNullBuffer.remaining() } else { encoder.compressedSize } - // Reserves 4 bytes for compression scheme ID val compressedBuffer = ByteBuffer + // Reserves 4 bytes for compression scheme ID .allocate(headerSize + 4 + compressedSize) .order(ByteOrder.nativeOrder) - - copyColumnHeader(rawBuffer, compressedBuffer) + // Write the header + .putInt(typeId) + .putInt(nullCount) + .put(nulls) logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(rawBuffer, compressedBuffer, columnType) + encoder.compress(nonNullBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index ba1810dd2ae66..7797f75177893 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -67,22 +67,6 @@ private[sql] object CompressionScheme { s"Unrecognized compression scheme type ID: $typeId")) } - def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) { - // Writes column type ID - to.putInt(from.getInt()) - - // Writes null count - val nullCount = from.getInt() - to.putInt(nullCount) - - // Writes null positions - var i = 0 - while (i < nullCount) { - to.putInt(from.getInt()) - i += 1 - } - } - def columnHeaderSize(columnBuffer: ByteBuffer): Int = { val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder) val nullCount = header.getInt(4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 9293239131d52..38f37564f1788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -53,10 +53,10 @@ case class SetCommand( if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - context.set(SQLConf.SHUFFLE_PARTITIONS, v) + context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") } else { - context.set(k, v) + context.setConf(k, v) Array(s"$k=$v") } @@ -77,14 +77,14 @@ case class SetCommand( "system:sun.java.command=shark.SharkServer2") } else { - Array(s"$k=${context.getOption(k).getOrElse("")}") + Array(s"$k=${context.getConf(k, "")}") } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => - context.getAll.map { case (k, v) => + context.getAllConfs.map { case (k, v) => s"$k=$v" - } + }.toSeq case _ => throw new IllegalArgumentException() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 1a58d73d9e7f4..584f71b3c13d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -29,21 +29,18 @@ class SQLConfSuite extends QueryTest { test("programmatic ways of basic setting and getting") { clear() - assert(getOption(testKey).isEmpty) - assert(getAll.toSet === Set()) + assert(getAllConfs.size === 0) - set(testKey, testVal) - assert(get(testKey) == testVal) - assert(get(testKey, testVal + "_") == testVal) - assert(getOption(testKey) == Some(testVal)) - assert(contains(testKey)) + setConf(testKey, testVal) + assert(getConf(testKey) == testVal) + assert(getConf(testKey, testVal + "_") == testVal) + assert(getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.get(testKey) == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getOption(testKey) == Some(testVal)) - assert(TestSQLContext.contains(testKey)) + assert(TestSQLContext.getConf(testKey) == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getAllConfs.contains(testKey)) clear() } @@ -51,21 +48,21 @@ class SQLConfSuite extends QueryTest { test("parse SQL set commands") { clear() sql(s"set $testKey=$testVal") - assert(get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + assert(getConf(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) sql("set some.property=20") - assert(get("some.property", "0") == "20") + assert(getConf("some.property", "0") == "20") sql("set some.property = 40") - assert(get("some.property", "0") == "40") + assert(getConf("some.property", "0") == "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(get(key, "0") == vs) + assert(getConf(key, "0") == vs) sql(s"set $key=") - assert(get(key, "0") == "") + assert(getConf(key, "0") == "") clear() } @@ -73,6 +70,6 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { clear() sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(get(SQLConf.SHUFFLE_PARTITIONS) == "10") + assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 6d688ea95cfc0..72c19fa31d980 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -42,4 +42,3 @@ object TestCompressibleColumnBuilder { builder } } - diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 7fac90fdc596d..c6f60c18804a4 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-hive-thriftserver_2.10 jar - Spark Project Hive + Spark Project Hive Thrift Server http://spark.apache.org/ hive-thriftserver diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d8e7a5943daa5..53f3dc11dbb9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -60,9 +60,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially or after a RESET command */ protected def configure() { - set("javax.jdo.option.ConnectionURL", + setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") - set("hive.metastore.warehouse.dir", warehousePath) + setConf("hive.metastore.warehouse.dir", warehousePath) } configure() // Must be called before initializing the catalog below. @@ -76,7 +76,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => // Change the default SQL dialect to HiveQL - override private[spark] def dialect: String = get(SQLConf.DIALECT, "hiveql") + override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -224,15 +224,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) @transient protected[hive] lazy val sessionState = { val ss = new SessionState(hiveconf) - set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + setConf(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. ss } sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") - override def set(key: String, value: String): Unit = { - super.set(key, value) + override def setConf(key: String, value: String): Unit = { + super.setConf(key, value) runSqlHive(s"SET $key=$value") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c605e8adcfb0f..d890df866fbe5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -65,9 +65,9 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially or after a RESET command */ protected def configure() { - set("javax.jdo.option.ConnectionURL", + setConf("javax.jdo.option.ConnectionURL", s"jdbc:derby:;databaseName=$metastorePath;create=true") - set("hive.metastore.warehouse.dir", warehousePath) + setConf("hive.metastore.warehouse.dir", warehousePath) } configure() // Must be called before initializing the catalog below. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2f0be49b6a6d7..fdb2f41f5a5b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -75,9 +75,9 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { - set("spark.sql.dialect", "sql") + setConf("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) - set("spark.sql.dialect", "hiveql") + setConf("spark.sql.dialect", "hiveql") } @@ -436,18 +436,18 @@ class HiveQuerySuite extends HiveComparisonTest { val testVal = "val0,val_1,val2.3,my_table" sql(s"set $testKey=$testVal") - assert(get(testKey, testVal + "_") == testVal) + assert(getConf(testKey, testVal + "_") == testVal) sql("set some.property=20") - assert(get("some.property", "0") == "20") + assert(getConf("some.property", "0") == "20") sql("set some.property = 40") - assert(get("some.property", "0") == "40") + assert(getConf("some.property", "0") == "40") sql(s"set $testKey=$testVal") - assert(get(testKey, "0") == testVal) + assert(getConf(testKey, "0") == testVal) sql(s"set $testKey=") - assert(get(testKey, "0") == "") + assert(getConf(testKey, "0") == "") } test("SET commands semantics for a HiveContext") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala new file mode 100644 index 0000000000000..635a9fb0d56cb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +/** + * A collection of hive query tests where we generate the answers ourselves instead of depending on + * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is + * valid, but Hive currently cannot execute it. + */ +class SQLQuerySuite extends QueryTest { + test("ordering not in select") { + checkAnswer( + sql("SELECT key FROM src ORDER BY value"), + sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq) + } + + test("ordering not in agg") { + checkAnswer( + sql("SELECT key FROM src GROUP BY key, value ORDER BY value"), + sql(""" + SELECT key + FROM ( + SELECT key, value + FROM src + GROUP BY key, value + ORDER BY value) a""").collect().toSeq) + } +}