Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

SPARK-1189: Add Security to Spark - Akka, Http, ConnectionManager, UI…

… use servlets

resubmit pull request.  was https://github.com/apache/incubator-spark/pull/332.

Author: Thomas Graves <tgraves@apache.org>

Closes #33 from tgravescs/security-branch-0.9-with-client-rebase and squashes the following commits:

dfe3918 [Thomas Graves] Fix merge conflict since startUserClass now using runAsUser
05eebed [Thomas Graves] Fix dependency lost in upmerge
d1040ec [Thomas Graves] Fix up various imports
05ff5e0 [Thomas Graves] Fix up imports after upmerging to master
ac046b3 [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase
13733e1 [Thomas Graves] Pass securityManager and SparkConf around where we can. Switch to use sparkConf for reading config whereever possible. Added ConnectionManagerSuite unit tests.
4a57acc [Thomas Graves] Change UI createHandler routines to createServlet since they now return servlets
2f77147 [Thomas Graves] Rework from comments
50dd9f2 [Thomas Graves] fix header in SecurityManager
ecbfb65 [Thomas Graves] Fix spacing and formatting
b514bec [Thomas Graves] Fix reference to config
ed3d1c1 [Thomas Graves] Add security.md
6f7ddf3 [Thomas Graves] Convert SaslClient and SaslServer to scala, change spark.authenticate.ui to spark.ui.acls.enable, and fix up various other things from review comments
2d9e23e [Thomas Graves] Merge remote-tracking branch 'upstream/master' into security-branch-0.9-with-client-rebase_rework
5721c5a [Thomas Graves] update AkkaUtilsSuite test for the actorSelection changes, fix typos based on comments, and remove extra lines I missed in rebase from AkkaUtils
f351763 [Thomas Graves] Add Security to Spark - Akka, Http, ConnectionManager, UI to use servlets
  • Loading branch information...
commit 7edbea41b43e0dc11a2de156be220db8b7952d01 1 parent 40566e1
@tgravescs tgravescs authored
Showing with 2,251 additions and 292 deletions.
  1. +16 −0 core/pom.xml
  2. +3 −2 core/src/main/scala/org/apache/spark/HttpFileServer.scala
  3. +55 −5 core/src/main/scala/org/apache/spark/HttpServer.scala
  4. +253 −0 core/src/main/scala/org/apache/spark/SecurityManager.scala
  5. +3 −1 core/src/main/scala/org/apache/spark/SparkContext.scala
  6. +13 −11 core/src/main/scala/org/apache/spark/SparkEnv.scala
  7. +146 −0 core/src/main/scala/org/apache/spark/SparkSaslClient.scala
  8. +174 −0 core/src/main/scala/org/apache/spark/SparkSaslServer.scala
  9. +3 −2 core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
  10. +2 −1  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
  11. +24 −8 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
  12. +3 −1 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
  13. +2 −2 core/src/main/scala/org/apache/spark/deploy/Client.scala
  14. +10 −0 core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
  15. +3 −2 core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
  16. +11 −6 core/src/main/scala/org/apache/spark/deploy/master/Master.scala
  17. +15 −10 core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
  18. +3 −2 core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
  19. +7 −5 core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
  20. +15 −11 core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
  21. +3 −2 core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
  22. +8 −7 core/src/main/scala/org/apache/spark/executor/Executor.scala
  23. +7 −6 core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
  24. +3 −1 core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
  25. +3 −1 core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
  26. +3 −1 core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
  27. +3 −1 core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
  28. +4 −1 core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
  29. +10 −4 core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
  30. +5 −3 core/src/main/scala/org/apache/spark/network/BufferMessage.scala
  31. +52 −9 core/src/main/scala/org/apache/spark/network/Connection.scala
  32. +34 −0 core/src/main/scala/org/apache/spark/network/ConnectionId.scala
  33. +258 −8 core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
  34. +1 −0  core/src/main/scala/org/apache/spark/network/Message.scala
  35. +8 −3 core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
  36. +3 −3 core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
  37. +163 −0 core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
  38. +3 −4 core/src/main/scala/org/apache/spark/network/SenderTest.scala
  39. +7 −5 core/src/main/scala/org/apache/spark/storage/BlockManager.scala
  40. +3 −1 core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
  41. +93 −45 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
  42. +9 −6 core/src/main/scala/org/apache/spark/ui/SparkUI.scala
  43. +4 −3 core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
  44. +4 −3 core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
  45. +11 −4 core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
  46. +8 −4 core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
  47. +14 −3 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
  48. +33 −4 core/src/main/scala/org/apache/spark/util/Utils.scala
  49. +215 −0 core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
  50. +1 −0  core/src/test/scala/org/apache/spark/BroadcastSuite.scala
  51. +230 −0 core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
  52. +1 −0  core/src/test/scala/org/apache/spark/DriverSuite.scala
  53. +26 −0 core/src/test/scala/org/apache/spark/FileServerSuite.scala
  54. +4 −2 core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
  55. +5 −4 core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
  56. +36 −31 core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
  57. +7 −3 core/src/test/scala/org/apache/spark/ui/UISuite.scala
  58. +51 −0 docs/configuration.md
  59. +1 −0  docs/index.md
  60. +18 −0 docs/security.md
  61. +4 −3 examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
  62. +20 −0 pom.xml
  63. +4 −0 project/SparkBuild.scala
  64. +11 −2 repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
  65. +14 −8 repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
  66. +8 −5 repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
  67. +23 −21 yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
  68. +4 −2 yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
  69. +1 −1  yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
  70. +23 −1 yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
  71. +22 −6 yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
  72. +4 −2 yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
View
16 core/pom.xml
@@ -66,6 +66,18 @@
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-plus</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-security</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-util</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
<dependency>
@@ -119,6 +131,10 @@
<version>0.3.1</version>
</dependency>
<dependency>
+ <groupId>commons-net</groupId>
+ <artifactId>commons-net</artifactId>
+ </dependency>
+ <dependency>
<groupId>${akka.group}</groupId>
<artifactId>akka-remote_${scala.binary.version}</artifactId>
</dependency>
View
5 core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,7 +23,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer extends Logging {
+private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
var baseDir : File = null
var fileDir : File = null
@@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging {
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir)
+ httpServer = new HttpServer(baseDir, securityManager)
httpServer.start()
serverUri = httpServer.uri
+ logDebug("HTTP file server started at: " + serverUri)
}
def stop() {
View
60 core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,15 +19,18 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.util.security.{Constraint, Password}
+import org.eclipse.jetty.security.authentication.DigestAuthenticator
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.DefaultHandler
-import org.eclipse.jetty.server.handler.HandlerList
-import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
+
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
@@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
+ extends Logging {
private var server: Server = null
private var port: Int = -1
@@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
- server.setHandler(handlerList)
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
+ /**
+ * Setup Jetty to the HashLoginService using a single user with our
+ * shared secret. Configure it to use DIGEST-MD5 authentication so that the password
+ * isn't passed in plaintext.
+ */
+ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = {
+ val constraint = new Constraint()
+ // use DIGEST-MD5 as the authentication mechanism
+ constraint.setName(Constraint.__DIGEST_AUTH)
+ constraint.setRoles(Array("user"))
+ constraint.setAuthenticate(true)
+ constraint.setDataConstraint(Constraint.DC_NONE)
+
+ val cm = new ConstraintMapping()
+ cm.setConstraint(constraint)
+ cm.setPathSpec("/*")
+ val sh = new ConstraintSecurityHandler()
+
+ // the hashLoginService lets us do a single user and
+ // secret right now. This could be changed to use the
+ // JAASLoginService for other options.
+ val hashLogin = new HashLoginService()
+
+ val userCred = new Password(securityMgr.getSecretKey())
+ if (userCred == null) {
+ throw new Exception("Error: secret key is null with authentication on")
+ }
+ hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user"))
+ sh.setLoginService(hashLogin)
+ sh.setAuthenticator(new DigestAuthenticator());
+ sh.setConstraintMappings(Array(cm))
+ sh
+ }
+
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
View
253 core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.{Authenticator, PasswordAuthentication}
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.deploy.SparkHadoopUtil
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Spark class responsible for security.
+ *
+ * In general this class should be instantiated by the SparkEnv and most components
+ * should access it from that. There are some cases where the SparkEnv hasn't been
+ * initialized yet and this class must be instantiated directly.
+ *
+ * Spark currently supports authentication via a shared secret.
+ * Authentication can be configured to be on via the 'spark.authenticate' configuration
+ * parameter. This parameter controls whether the Spark communication protocols do
+ * authentication using the shared secret. This authentication is a basic handshake to
+ * make sure both sides have the same shared secret and are allowed to communicate.
+ * If the shared secret is not identical they will not be allowed to communicate.
+ *
+ * The Spark UI can also be secured by using javax servlet filters. A user may want to
+ * secure the UI if it has data that other users should not be allowed to see. The javax
+ * servlet filter specified by the user can authenticate the user and then once the user
+ * is logged in, Spark can compare that user versus the view acls to make sure they are
+ * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls'
+ * control the behavior of the acls. Note that the person who started the application
+ * always has view access to the UI.
+ *
+ * Spark does not currently support encryption after authentication.
+ *
+ * At this point spark has multiple communication protocols that need to be secured and
+ * different underlying mechanisms are used depending on the protocol:
+ *
+ * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality.
+ * Akka remoting allows you to specify a secure cookie that will be exchanged
+ * and ensured to be identical in the connection handshake between the client
+ * and the server. If they are not identical then the client will be refused
+ * to connect to the server. There is no control of the underlying
+ * authentication mechanism so its not clear if the password is passed in
+ * plaintext or uses DIGEST-MD5 or some other mechanism.
+ * Akka also has an option to turn on SSL, this option is not currently supported
+ * but we could add a configuration option in the future.
+ *
+ * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
+ * for the HttpServer. Jetty supports multiple authentication mechanisms -
+ * Basic, Digest, Form, Spengo, etc. It also supports multiple different login
+ * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService
+ * to authenticate using DIGEST-MD5 via a single user and the shared secret.
+ * Since we are using DIGEST-MD5, the shared secret is not passed on the wire
+ * in plaintext.
+ * We currently do not support SSL (https), but Jetty can be configured to use it
+ * so we could add a configuration option for this in the future.
+ *
+ * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
+ * Any clients must specify the user and password. There is a default
+ * Authenticator installed in the SecurityManager to how it does the authentication
+ * and in this case gets the user name and password from the request.
+ *
+ * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * exchange messages. For this we use the Java SASL
+ * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
+ * as the authentication mechanism. This means the shared secret is not passed
+ * over the wire in plaintext.
+ * Note that SASL is pluggable as to what mechanism it uses. We currently use
+ * DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
+ * Spark currently supports "auth" for the quality of protection, which means
+ * the connection is not supporting integrity or privacy protection (encryption)
+ * after authentication. SASL also supports "auth-int" and "auth-conf" which
+ * SPARK could be support in the future to allow the user to specify the quality
+ * of protection they want. If we support those, the messages will also have to
+ * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
+ *
+ * Since the connectionManager does asynchronous messages passing, the SASL
+ * authentication is a bit more complex. A ConnectionManager can be both a client
+ * and a Server, so for a particular connection is has to determine what to do.
+ * A ConnectionId was added to be able to track connections and is used to
+ * match up incoming messages with connections waiting for authentication.
+ * If its acting as a client and trying to send a message to another ConnectionManager,
+ * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId
+ * and waits for the response from the server and does the handshake.
+ *
+ * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
+ * can be used. Yarn requires a specific AmIpFilter be installed for security to work
+ * properly. For non-Yarn deployments, users can write a filter to go through a
+ * companies normal login service. If an authentication filter is in place then the
+ * SparkUI can be configured to check the logged in user against the list of users who
+ * have view acls to see if that user is authorized.
+ * The filters can also be used for many different purposes. For instance filters
+ * could be used for logging, encryption, or compression.
+ *
+ * The exact mechanisms used to generate/distributed the shared secret is deployment specific.
+ *
+ * For Yarn deployments, the secret is automatically generated using the Akka remote
+ * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
+ * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels
+ * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn
+ * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn
+ * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there
+ * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use
+ * filters to do authentication. That authentication then happens via the ResourceManager Proxy
+ * and Spark will use that to do authorization against the view acls.
+ *
+ * For other Spark deployments, the shared secret must be specified via the
+ * spark.authenticate.secret config.
+ * All the nodes (Master and Workers) and the applications need to have the same shared secret.
+ * This again is not ideal as one user could potentially affect another users application.
+ * This should be enhanced in the future to provide better protection.
+ * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * authentication. Spark will then use that user to compare against the view acls to do
+ * authorization. If not filter is in place the user is generally null and no authorization
+ * can take place.
+ */
+
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+
+ // key used to store the spark secret in the Hadoop UGI
+ private val sparkSecretLookupKey = "sparkCookie"
+
+ private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)
+
+ // always add the current user and SPARK_USER to the viewAcls
+ private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""),
+ Option(System.getenv("SPARK_USER")).getOrElse(""))
+ aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',')
+ private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet
+
+ private val secretKey = generateSecretKey()
+ logInfo("SecurityManager, is authentication enabled: " + authOn +
+ " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
+
+ // Set our own authenticator to properly negotiate user/password for HTTP connections.
+ // This is needed by the HTTP client fetching from the HttpServer. Put here so its
+ // only set once.
+ if (authOn) {
+ Authenticator.setDefault(
+ new Authenticator() {
+ override def getPasswordAuthentication(): PasswordAuthentication = {
+ var passAuth: PasswordAuthentication = null
+ val userInfo = getRequestingURL().getUserInfo()
+ if (userInfo != null) {
+ val parts = userInfo.split(":", 2)
+ passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
+ }
+ return passAuth
+ }
+ }
+ )
+ }
+
+ /**
+ * Generates or looks up the secret key.
+ *
+ * The way the key is stored depends on the Spark deployment mode. Yarn
+ * uses the Hadoop UGI.
+ *
+ * For non-Yarn deployments, If the config variable is not set
+ * we throw an exception.
+ */
+ private def generateSecretKey(): String = {
+ if (!isAuthenticationEnabled) return null
+ // first check to see if the secret is already set, else generate a new one if on yarn
+ val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
+ if (secretKey != null) {
+ logDebug("in yarn mode, getting secret from credentials")
+ return new Text(secretKey).toString
+ } else {
+ logDebug("getSecretKey: yarn mode, secret key from credentials is null")
+ }
+ val cookie = akka.util.Crypt.generateSecureCookie
+ // if we generated the secret then we must be the first so lets set it so t
+ // gets used by everyone else
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
+ logInfo("adding secret to credentials in yarn mode")
+ cookie
+ } else {
+ // user must have set spark.authenticate.secret config
+ sparkConf.getOption("spark.authenticate.secret") match {
+ case Some(value) => value
+ case None => throw new Exception("Error: a secret key must be specified via the " +
+ "spark.authenticate.secret config")
+ }
+ }
+ sCookie
+ }
+
+ /**
+ * Check to see if Acls for the UI are enabled
+ * @return true if UI authentication is enabled, otherwise false
+ */
+ def uiAclsEnabled(): Boolean = uiAclsOn
+
+ /**
+ * Checks the given user against the view acl list to see if they have
+ * authorization to view the UI. If the UI acls must are disabled
+ * via spark.ui.acls.enable, all users have view access.
+ *
+ * @param user to see if is authorized
+ * @return true is the user has permission, otherwise false
+ */
+ def checkUIViewPermissions(user: String): Boolean = {
+ if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true
+ }
+
+ /**
+ * Check to see if authentication for the Spark communication protocols is enabled
+ * @return true if authentication is enabled, otherwise false
+ */
+ def isAuthenticationEnabled(): Boolean = authOn
+
+ /**
+ * Gets the user used for authenticating HTTP connections.
+ * For now use a single hardcoded user.
+ * @return the HTTP user as a String
+ */
+ def getHttpUser(): String = "sparkHttpUser"
+
+ /**
+ * Gets the user used for authenticating SASL connections.
+ * For now use a single hardcoded user.
+ * @return the SASL user as a String
+ */
+ def getSaslUser(): String = "sparkSaslUser"
+
+ /**
+ * Gets the secret key.
+ * @return the secret key as a String if authentication is enabled, otherwise returns null
+ */
+ def getSecretKey(): String = secretKey
+}
View
4 core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -130,6 +130,8 @@ class SparkContext(
val isLocal = (master == "local" || master.startsWith("local["))
+ if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -634,7 +636,7 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
View
24 core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -53,7 +53,8 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) extends Logging {
+ val conf: SparkConf,
+ val securityManager: SecurityManager) extends Logging {
// A mapping of thread ID to amount of memory used for shuffle in bytes
// All accesses should be manually synchronized
@@ -122,8 +123,9 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
- conf = conf)
+ val securityManager = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
+ securityManager = securityManager)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
@@ -139,7 +141,6 @@ object SparkEnv extends Logging {
val name = conf.get(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
-
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
@@ -167,12 +168,12 @@ object SparkEnv extends Logging {
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf)), conf)
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ serializer, conf, securityManager)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver, conf)
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val cacheManager = new CacheManager(blockManager)
@@ -190,14 +191,14 @@ object SparkEnv extends Logging {
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
- val httpFileServer = new HttpFileServer()
+ val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver", conf)
+ MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf)
+ MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()
@@ -231,6 +232,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
- conf)
+ conf,
+ securityManager)
}
}
View
146 core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.IOException
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.RealmChoiceCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslClient
+import javax.security.sasl.SaslException
+
+import scala.collection.JavaConversions.mapAsJavaMap
+
+/**
+ * Implements SASL Client logic for Spark
+ */
+private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Used to respond to server's counterpart, SaslServer with SASL tokens
+ * represented as byte arrays.
+ *
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
+ null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslClientCallbackHandler(securityMgr))
+
+ /**
+ * Used to initiate SASL handshake with server.
+ * @return response to challenge if needed
+ */
+ def firstToken(): Array[Byte] = {
+ synchronized {
+ val saslToken: Array[Byte] =
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ logDebug("has initial response")
+ saslClient.evaluateChallenge(new Array[Byte](0))
+ } else {
+ new Array[Byte](0)
+ }
+ saslToken
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true is complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslClient != null) saslClient.isComplete() else false
+ }
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param saslTokenMessage contains server's SASL token
+ * @return client's response SASL token
+ */
+ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose()
+ } catch {
+ case e: SaslException => // ignored
+ } finally {
+ saslClient = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * that works with share secrets.
+ */
+ private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
+ CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+ private val secretKey = securityMgr.getSecretKey()
+ private val userPassword: Array[Char] =
+ SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes())
+
+ /**
+ * Implementation used to respond to SASL request from the server.
+ *
+ * @param callbacks objects that indicate what credential information the
+ * server's SaslServer requires from the client.
+ */
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("in the sasl client callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL client callback: setting username: " + userName)
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL client callback: setting userPassword")
+ pc.setPassword(userPassword)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case cb: RealmChoiceCallback => {}
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
+ }
+ }
+ }
+}
View
174 core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslException
+import javax.security.sasl.SaslServer
+import scala.collection.JavaConversions.mapAsJavaMap
+import org.apache.commons.net.util.Base64
+
+/**
+ * Encapsulates SASL server logic
+ */
+private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Actual SASL work done by this object from javax.security.sasl.
+ */
+ private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
+ SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslDigestCallbackHandler(securityMgr))
+
+ /**
+ * Determines whether the authentication exchange has completed.
+ * @return true is complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslServer != null) saslServer.isComplete() else false
+ }
+ }
+
+ /**
+ * Used to respond to server SASL tokens.
+ * @param token Server's SASL token
+ * @return response to send back to the server.
+ */
+ def response(token: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose()
+ } catch {
+ case e: SaslException => // ignore
+ } finally {
+ saslServer = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * for SASL DIGEST-MD5 mechanism
+ */
+ private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+ extends CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("In the sasl server callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL server callback: setting username")
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL server callback: setting userPassword")
+ val password: Array[Char] =
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes())
+ pc.setPassword(password)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case ac: AuthorizeCallback => {
+ val authid = ac.getAuthenticationID()
+ val authzid = ac.getAuthorizationID()
+ if (authid.equals(authzid)) {
+ logDebug("set auth to true")
+ ac.setAuthorized(true)
+ } else {
+ logDebug("set auth to false")
+ ac.setAuthorized(false)
+ }
+ if (ac.isAuthorized()) {
+ logDebug("sasl server is authorized")
+ ac.setAuthorizedID(authzid)
+ }
+ }
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+ }
+ }
+ }
+}
+
+private[spark] object SparkSaslServer {
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ val SASL_DEFAULT_REALM = "default"
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ val DIGEST = "DIGEST-MD5"
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+ * authentication only, we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+ val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
+
+ /**
+ * Encode a byte[] identifier as a Base64-encoded string.
+ *
+ * @param identifier identifier to encode
+ * @return Base64-encoded string
+ */
+ def encodeIdentifier(identifier: Array[Byte]): String = {
+ new String(Base64.encodeBase64(identifier))
+ }
+
+ /**
+ * Encode a password as a base64-encoded char[] array.
+ * @param password as a byte array.
+ * @return password as a char array.
+ */
+ def encodePassword(password: Array[Byte]): Array[Char] = {
+ new String(Base64.encodeBase64(password)).toCharArray()
+ }
+}
+
View
5 core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
+ extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf)
+ broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
View
3  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.broadcast
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
@@ -26,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
View
32 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -18,13 +18,13 @@
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
@@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging {
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
+ private var securityManager: SecurityManager = null
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
@@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging {
private var compressionCodec: CompressionCodec = null
- def initialize(isDriver: Boolean, conf: SparkConf) {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
if (!initialized) {
bufferSize = conf.getInt("spark.buffer.size", 65536)
compress = conf.getBoolean("spark.broadcast.compress", true)
+ securityManager = securityMgr
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
@@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir)
+ server = new HttpServer(broadcastDir, securityManager)
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
+
+ var uc: URLConnection = null
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("broadcast security enabled")
+ val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("broadcast not using security")
+ uc = new URL(url).openConnection()
+ }
+
val in = {
- val httpConnection = new URL(url).openConnection()
- httpConnection.setReadTimeout(httpReadTimeout)
- val inputStream = httpConnection.getInputStream
+ uc.setReadTimeout(httpReadTimeout)
+ val inputStream = uc.getInputStream();
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
View
4 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -241,7 +241,9 @@ private[spark] case class TorrentInfo(
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
View
4 core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -141,7 +141,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, false, conf)
+ "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
View
10 core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
@@ -65,6 +66,15 @@ class SparkHadoopUtil {
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+
+ def getCurrentUserCredentials(): Credentials = { null }
+
+ def addCurrentUserCredentials(creds: Credentials) {}
+
+ def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+ def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
}
object SparkHadoopUtil {
View
5 core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.client
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -45,8 +45,9 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
+ val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
- conf = new SparkConf)
+ conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
Some("dummy-spark-home"), "ignored")
View
17 core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
-private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int,
+ val securityMgr: SecurityManager) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
@@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
+ securityMgr)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -711,8 +713,11 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
: (ActorSystem, Int, Int) =
{
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val securityMgr = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
+ securityManager = securityMgr)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
+ securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
View
25 core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -18,8 +18,8 @@
package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
@@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
@@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++
master.applicationMetricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)),
- ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)),
- ("/app", (request: HttpServletRequest) => applicationPage.render(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"),
+ createServletHandler("/app/json",
+ createServlet((request: HttpServletRequest) => applicationPage.renderJson(request),
+ master.securityMgr)),
+ createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage
+ .render(request), master.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), master.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), master.securityMgr))
)
def stop() {
@@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
}
private[spark] object MasterWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
View
5 core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker
import akka.actor._
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -29,8 +29,9 @@ object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
case workerUrl :: mainClass :: extraArgs =>
+ val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
- Utils.localHostName(), 0, false, new SparkConf())
+ Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
// Delegate to supplied main class
View
12 core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -27,7 +27,7 @@ import scala.concurrent.duration._
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
@@ -48,7 +48,8 @@ private[spark] class Worker(
actorSystemName: String,
actorName: String,
workDirPath: String = null,
- val conf: SparkConf)
+ val conf: SparkConf,
+ val securityMgr: SecurityManager)
extends Actor with Logging {
import context.dispatcher
@@ -91,7 +92,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
- val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
def coresFree: Int = cores - coresUsed
@@ -347,10 +348,11 @@ private[spark] object Worker {
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
+ val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf)
+ conf = conf, securityManager = securityMgr)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf), name = actorName)
+ masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
View
26 core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui
import java.io.File
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
@@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
*/
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
- extends Logging {
+ extends Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
@@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val metricsHandlers = worker.metricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)),
- ("/log", (request: HttpServletRequest) => log(request)),
- ("/logPage", (request: HttpServletRequest) => logPage(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"),
+ createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request),
+ worker.securityMgr)),
+ createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage
+ (request), worker.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), worker.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), worker.securityMgr))
)
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))
@@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
}
private[spark] object WorkerWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_BASE = "org/apache/spark/ui"
val DEFAULT_PORT="8081"
}
View
5 core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor._
import akka.remote._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend {
// Debug code
Utils.checkHost(hostname)
+ val conf = new SparkConf
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- indestructible = true, conf = new SparkConf)
+ indestructible = true, conf = conf, new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
View
15 core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -69,11 +69,6 @@ private[spark] class Executor(
conf.set("spark.local.dir", getYarnLocalDirs())
}
- // Create our ClassLoader and set it on this thread
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
-
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,6 +112,12 @@ private[spark] class Executor(
}
}
+ // Create our ClassLoader and set it on this thread
+ // do this after SparkEnv creation so can access the SecurityManager
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
@@ -338,12 +339,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
View
13 core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
* [options] is the specific property of this source or sink.
*/
private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf) extends Logging {
+ conf: SparkConf, securityMgr: SecurityManager) extends Logging {
val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
@@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String,
val classPath = kv._2.getProperty("class")
try {
val sink = Class.forName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry])
- .newInstance(kv._2, registry)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
} else {
@@ -160,6 +160,7 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
- new MetricsSystem(instance, conf)
+ def createMetricsSystem(instance: String, conf: SparkConf,
+ securityMgr: SecurityManager): MetricsSystem =
+ new MetricsSystem(instance, conf, securityMgr)
}
View
4 core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class ConsoleSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CONSOLE_DEFAULT_PERIOD = 10
val CONSOLE_DEFAULT_UNIT = "SECONDS"
View
4 core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{CsvReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class CsvSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CSV_KEY_PERIOD = "period"
val CSV_KEY_UNIT = "unit"
val CSV_KEY_DIR = "directory"
View
4 core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
@@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.ganglia.GangliaReporter
import info.ganglia.gmetric4j.gmetric.GMetric
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GangliaSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GANGLIA_KEY_PERIOD = "period"
val GANGLIA_DEFAULT_PERIOD = 10
View
4 core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GRAPHITE_DEFAULT_PERIOD = 10
val GRAPHITE_DEFAULT_UNIT = "SECONDS"
val GRAPHITE_DEFAULT_PREFIX = ""
View
5 core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+class JmxSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
override def start() {
View
14 core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
+
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.apache.spark.SecurityManager
import org.apache.spark.ui.JettyUtils
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val SERVLET_KEY_PATH = "path"
val SERVLET_KEY_SAMPLE = "sample"
@@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
val mapper = new ObjectMapper().registerModule(
new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
- def getHandlers = Array[(String, Handler)](
- (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ def getHandlers = Array[ServletContextHandler](
+ JettyUtils.createServletHandler(servletPath,
+ JettyUtils.createServlet(
+ new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"),
+ securityMgr) )
)
def getMetricsSnapshot(request: HttpServletRequest): String = {
View
8 core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
+ val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
@@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
View
61 core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -17,6 +17,11 @@
package org.apache.spark.network
+import org.apache.spark._
+import org.apache.spark.SparkSaslServer
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
import java.net._
import java.nio._
import java.nio.channels._
@@ -27,13 +32,16 @@ import org.apache.spark._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
- def this(channel_ : SocketChannel, selector_ : Selector) = {
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
}
channel.configureBlocking(false)
@@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
def resetForceReregister(): Boolean
// Read channels typically do not register for write and write does not for read
@@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose();
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
@@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
k.cancel()
}
channel.close()
+ disposeSasl()
callOnCloseCallback()
}
@@ -168,8 +197,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
- extends Connection(SocketChannel.open, selector_, remoteId_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
private class Outbox {
val messages = new Queue[Message]()
@@ -226,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
data as detailed in https://github.com/mesos/spark/pull/791
*/
private var needForceReregister = false
+
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
@@ -316,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// If we have 'seen' pending messages, then reset flag - since we handle that as
// normal registering of event (below)
if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
currentBuffers ++= buffers
}
case None => {
@@ -384,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
- extends Connection(channel_, selector_) {
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
@@ -396,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
@@ -441,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
+ var onReceiveCallback: (Connection, Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
@@ -516,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
View
34 core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
+}
+
+private[spark] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
+}
View
266 core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -21,6 +21,9 @@ import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
@@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
+
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+ securityManager: SecurityManager) extends Logging {
class MessageStatus(
val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
+ // default to 30 second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+
private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
conf.getInt("spark.core.connection.handler.threads.max", 60),
@@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
private val connectionsByKey =
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
@@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
@@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
@@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
+ val needReregister = register || conn.resetForceReregister()
if (needReregister && conn.changeInterestForWrite()) {
conn.registerInterest()
}
@@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
// accept them all in a tight loop. non blocking accept with no processing, should be fine
while (newChannel != null) {
try {
- val newConnection = new ReceivingConnection(newChannel, selector)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
messageStatuses.synchronized {
messageStatuses
@@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
+ handleMessage(connectionManagerId, message, connection)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
@@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
/*handleMessage(connection, message)*/
}
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll();
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken);
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString())
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new Exception("Error evaluating sasl response: " + e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling other side auth failed
+ // but for now just close
+ connection.close()