[SPARK-6229] Add SASL encryption to network library.
There are two main parts of this change:

- Extending the bootstrap mechanism in the network library to add a server-side
  bootstrap (which works a little differently from the client-side bootstrap), and
  to allow bootstraps to modify the underlying channel; see the interface sketch
  after this list.

- Using SASL to encrypt data going through the RPC channel.
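
The server-side bootstrap is the key new API. Its definition is not part of the
excerpt below, but from its call sites it has roughly this shape (a sketch, not
necessarily the exact source):

    import io.netty.channel.Channel;
    import org.apache.spark.network.server.RpcHandler;

    /**
     * Sketch of the server-side bootstrap contract: run once per accepted
     * channel, it may install extra pipeline handlers (e.g. encryption) and
     * may return a wrapped RpcHandler (which is how SASL interposes itself).
     */
    public interface TransportServerBootstrap {
      RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler);
    }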

The second item requires some non-optimal code to work around the fact that
Netty's outbound path is not thread-safe, and ordering is very important when
encryption is in the picture.
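
To illustrate the ordering concern, here is a minimal, hypothetical outbound
handler (not the commit's actual SaslEncryption class) that SASL-wraps every
outbound buffer. Because Netty runs outbound handlers on the channel's single
event-loop thread, doing the wrap() inside write() keeps ciphertext blocks in
the same order as the plaintext messages:

    import io.netty.buffer.ByteBuf;
    import io.netty.channel.ChannelHandlerContext;
    import io.netty.channel.ChannelOutboundHandlerAdapter;
    import io.netty.channel.ChannelPromise;
    import javax.security.sasl.SaslClient;

    class SaslEncryptHandler extends ChannelOutboundHandlerAdapter {
      private final SaslClient saslClient;

      SaslEncryptHandler(SaslClient saslClient) {
        this.saslClient = saslClient;
      }

      @Override
      public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
          throws Exception {
        ByteBuf buf = (ByteBuf) msg;
        try {
          byte[] plaintext = new byte[buf.readableBytes()];
          buf.readBytes(plaintext);
          // wrap() produces one ciphertext block; the length prefix lets the
          // peer frame it and call unwrap() on exactly one block.
          byte[] ciphertext = saslClient.wrap(plaintext, 0, plaintext.length);
          ByteBuf out = ctx.alloc().buffer(4 + ciphertext.length);
          out.writeInt(ciphertext.length);
          out.writeBytes(ciphertext);
          ctx.write(out, promise);
        } finally {
          buf.release();
        }
      }
    }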

A lot of the changes outside the network/common library are just to adjust to the
changed API for initializing the RPC server.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes apache#5377 from vanzin/SPARK-6229 and squashes the following commits:

ff01966 [Marcelo Vanzin] Use fancy new size config style.
be53f32 [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
47d4aff [Marcelo Vanzin] Merge branch 'master' into SPARK-6229
7a2a805 [Marcelo Vanzin] Clean up some unneeded changes.
2f92237 [Marcelo Vanzin] Add comment.
67bb0c6 [Marcelo Vanzin] Revert "Avoid exposing ByteArrayWritableChannel outside of test code."
065f684 [Marcelo Vanzin] Add test to verify chunking.
3d1695d [Marcelo Vanzin] Minor cleanups.
73cff0e [Marcelo Vanzin] Skip bytes in decode path too.
318ad23 [Marcelo Vanzin] Avoid exposing ByteArrayWritableChannel outside of test code.
346f829 [Marcelo Vanzin] Avoid trip through channel selector by not reporting 0 bytes written.
a4a5938 [Marcelo Vanzin] Review feedback.
4797519 [Marcelo Vanzin] Remove unused import.
9908ada [Marcelo Vanzin] Fix test, SASL backend disposal.
7fe1489 [Marcelo Vanzin] Add a test that makes sure encryption is actually enabled.
adb6f9d [Marcelo Vanzin] Review feedback.
cf2a605 [Marcelo Vanzin] Clean up some code.
8584323 [Marcelo Vanzin] Fix a comment.
e98bc55 [Marcelo Vanzin] Add option to only allow encrypted connections to the server.
dad42fc [Marcelo Vanzin] Make encryption thread-safe, less memory-intensive.
b00999a [Marcelo Vanzin] Consolidate ByteArrayWritableChannel, fix SASL code to match master changes.
b923cae [Marcelo Vanzin] Make SASL encryption handler thread-safe, handle FileRegion messages.
39539a7 [Marcelo Vanzin] Add config option to enable SASL encryption.
351a86f [Marcelo Vanzin] Add SASL encryption to network library.
fbe6ccb [Marcelo Vanzin] Add TransportServerBootstrap, make SASL code use it.
Marcelo Vanzin authored and rxin committed May 2, 2015
1 parent (8f50a07) · commit 38d4e9e
Showing 27 changed files with 1,070 additions and 106 deletions.
core/src/main/scala/org/apache/spark/SecurityManager.scala (15 additions & 2 deletions)
@@ -150,8 +150,13 @@ import org.apache.spark.util.Utils
* authorization. If no filter is in place the user is generally null and no authorization
* can take place.
*
- * Connection encryption (SSL) configuration is organized hierarchically. The user can configure
- * the default SSL settings which will be used for all the supported communication protocols unless
+ * When authentication is being used, encryption can also be enabled by setting the option
+ * spark.authenticate.enableSaslEncryption to true. This is only supported by communication
+ * channels that use the network-common library, and can be used as an alternative to SSL in those
+ * cases.
+ *
+ * SSL can be used for encryption for certain communication channels. The user can configure the
+ * default SSL settings which will be used for all the supported communication protocols unless
* they are overwritten by protocol specific settings. This way the user can easily provide the
* common settings for all the protocols without disabling the ability to configure each one
* individually.
@@ -412,6 +417,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
*/
def isAuthenticationEnabled(): Boolean = authOn

+ /**
+  * Checks whether SASL encryption should be enabled.
+  * @return Whether to enable SASL encryption when connecting to services that support it.
+  */
+ def isSaslEncryptionEnabled(): Boolean = {
+   sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false)
+ }

/**
* Gets the user used for authenticating HTTP connections.
* For now use a single hardcoded user.
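For reference, a hypothetical snippet showing how the new option is turned on
from application code; SASL encryption only takes effect when
spark.authenticate is also enabled:

    import org.apache.spark.SparkConf;

    public class SaslEncryptionConfExample {
      public static void main(String[] args) {
        // SASL encryption presupposes SASL authentication, so both flags are set.
        SparkConf conf = new SparkConf()
            .set("spark.authenticate", "true")
            .set("spark.authenticate.enableSaslEncryption", "true");
        System.out.println(conf.get("spark.authenticate.enableSaslEncryption"));
      }
    }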
core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -19,10 +19,12 @@ package org.apache.spark.deploy

import java.util.concurrent.CountDownLatch

+ import scala.collection.JavaConversions._

import org.apache.spark.{Logging, SparkConf, SecurityManager}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
- import org.apache.spark.network.sasl.SaslRpcHandler
+ import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.util.Utils
@@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana

private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
-   private val transportContext: TransportContext = {
-     val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
-     new TransportContext(transportConf, handler)
-   }
+   private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)

private var server: TransportServer = _

@@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
def start() {
require(server == null, "Shuffle server already started")
logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
-     server = transportContext.createServer(port)
+     val bootstraps =
+       if (useSasl) {
+         Seq(new SaslServerBootstrap(transportConf, securityManager))
+       } else {
+         Nil
+       }
+     server = transportContext.createServer(port, bootstraps)
}

def stop() {
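Pulled out of the diff above into standalone Java, the server-side wiring
pattern looks roughly like this. The class name, key values, and port are
placeholders, and the constructors are assumed to match the ones used in this
commit:

    import java.util.Arrays;
    import java.util.List;
    import java.util.NoSuchElementException;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.sasl.SaslServerBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.server.TransportServerBootstrap;
    import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler;
    import org.apache.spark.network.util.ConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class SaslShuffleServerSketch {
      public static void main(String[] args) {
        // Empty config provider: every setting falls back to its default.
        TransportConf conf = new TransportConf(new ConfigProvider() {
          @Override
          public String get(String name) { throw new NoSuchElementException(name); }
        });
        // Stand-in for SecurityManager, which implements SecretKeyHolder in Spark.
        SecretKeyHolder keys = new SecretKeyHolder() {
          @Override
          public String getSaslUser(String appId) { return "sparkSaslUser"; }
          @Override
          public String getSecretKey(String appId) { return "shared-secret"; }
        };
        TransportContext context =
            new TransportContext(conf, new ExternalShuffleBlockHandler(conf));
        // With the bootstrap list, authentication (and optionally encryption)
        // is negotiated before the handler sees any application messages.
        List<TransportServerBootstrap> bootstraps =
            Arrays.<TransportServerBootstrap>asList(new SaslServerBootstrap(conf, keys));
        TransportServer server = context.createServer(7337, bootstraps);
        System.out.println("Listening on port " + server.getPort());
      }
    }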
core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
- import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
+ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
@@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
private[this] var appId: String = _

override def init(blockDataManager: BlockDataManager): Unit = {
-     val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
-       val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
-       if (!authEnabled) {
-         (nettyRpcHandler, None)
-       } else {
-         (new SaslRpcHandler(nettyRpcHandler, securityManager),
-           Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
-       }
+     val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+     var serverBootstrap: Option[TransportServerBootstrap] = None
+     var clientBootstrap: Option[TransportClientBootstrap] = None
+     if (authEnabled) {
+       serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager))
+       clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager,
+         securityManager.isSaslEncryptionEnabled()))
+     }
}
transportContext = new TransportContext(transportConf, rpcHandler)
-     clientFactory = transportContext.createClientFactory(bootstrap.toList)
-     server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
+     clientFactory = transportContext.createClientFactory(clientBootstrap.toList)
+     server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0),
+       serverBootstrap.toList)
appId = conf.getAppId
logInfo("Server created on " + server.getPort)
}
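The client side mirrors this: a SaslClientBootstrap goes to the client factory,
and its final argument asks it to negotiate encryption on top of
authentication. A hypothetical helper, reusing the conf and keys from the
server sketch above (app id, host, and port are placeholders):

    import java.io.IOException;
    import java.util.Arrays;

    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;
    import org.apache.spark.network.client.TransportClientFactory;
    import org.apache.spark.network.sasl.SaslClientBootstrap;
    import org.apache.spark.network.sasl.SecretKeyHolder;
    import org.apache.spark.network.util.TransportConf;

    public class SaslClientSketch {
      static TransportClient connect(TransportContext context, TransportConf conf,
          SecretKeyHolder keys) throws IOException {
        TransportClientFactory factory = context.createClientFactory(
            Arrays.<TransportClientBootstrap>asList(
                // The trailing true requests SASL encryption (auth-conf QOP).
                new SaslClientBootstrap(conf, "app-1", keys, true)));
        // Bootstraps run inside createClient(), so the returned client is
        // already authenticated and its channel already encrypts traffic.
        return factory.createClient("localhost", 7337);
      }
    }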
core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -656,7 +656,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
-         connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
+         connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -800,7 +800,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
-         conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
+         conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -111,7 +111,8 @@ private[spark] class BlockManager(
// standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
-     new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
+     new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
+       securityManager.isSaslEncryptionEnabled())
} else {
blockTransferService
}
network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -36,6 +36,7 @@
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.server.TransportServer;
+ import org.apache.spark.network.server.TransportServerBootstrap;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

@@ -82,27 +83,40 @@ public TransportClientFactory createClientFactory() {
}

/** Create a server which will attempt to bind to a specific port. */
-   public TransportServer createServer(int port) {
-     return new TransportServer(this, port);
+   public TransportServer createServer(int port, List<TransportServerBootstrap> bootstraps) {
+     return new TransportServer(this, port, rpcHandler, bootstraps);
}

+   /** Creates a new server, binding to any available ephemeral port. */
+   public TransportServer createServer(List<TransportServerBootstrap> bootstraps) {
+     return createServer(0, bootstraps);
+   }

public TransportServer createServer() {
-     return new TransportServer(this, 0);
+     return createServer(0, Lists.<TransportServerBootstrap>newArrayList());
}

+   public TransportChannelHandler initializePipeline(SocketChannel channel) {
+     return initializePipeline(channel, rpcHandler);
+   }

/**
* Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
* has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
* response messages.
*
+    * @param channel The channel to initialize.
+    * @param channelRpcHandler The RPC handler to use for the channel.
+    *
* @return Returns the created TransportChannelHandler, which includes a TransportClient that can
* be used to communicate on this channel. The TransportClient is directly associated with a
* ChannelHandler to ensure all users of the same channel get the same TransportClient object.
*/
-   public TransportChannelHandler initializePipeline(SocketChannel channel) {
+   public TransportChannelHandler initializePipeline(
+       SocketChannel channel,
+       RpcHandler channelRpcHandler) {
try {
-       TransportChannelHandler channelHandler = createChannelHandler(channel);
+       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
channel.pipeline()
.addLast("encoder", encoder)
.addLast("frameDecoder", NettyUtils.createFrameDecoder())
@@ -123,7 +137,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) {
* ResponseMessages. The channel is expected to have been successfully created, though certain
* properties (such as the remoteAddress()) may not be available yet.
*/
-   private TransportChannelHandler createChannelHandler(Channel channel) {
+   private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) {
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
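The reason initializePipeline() gains an RpcHandler parameter is that server
bootstraps may substitute a per-channel handler. A sketch of the assumed flow
inside TransportServer's channel initializer (not the literal source):

    import java.util.List;

    import io.netty.channel.socket.SocketChannel;
    import org.apache.spark.network.TransportContext;
    import org.apache.spark.network.server.RpcHandler;
    import org.apache.spark.network.server.TransportServerBootstrap;

    class BootstrapChain {
      static void initialize(TransportContext context, SocketChannel channel,
          RpcHandler appHandler, List<TransportServerBootstrap> bootstraps) {
        RpcHandler handler = appHandler;
        for (TransportServerBootstrap bootstrap : bootstraps) {
          // Each bootstrap may wrap the handler (SASL authentication) and/or
          // add pipeline handlers to the channel (SASL encryption).
          handler = bootstrap.doBootstrap(channel, handler);
        }
        context.initializePipeline(channel, handler);
      }
    }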
network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java
@@ -17,6 +17,8 @@

package org.apache.spark.network.client;

+ import io.netty.channel.Channel;

/**
* A bootstrap which is executed on a TransportClient before it is returned to the user.
* This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
@@ -28,5 +30,5 @@
*/
public interface TransportClientBootstrap {
/** Performs the bootstrapping operation, throwing an exception on failure. */
-   public void doBootstrap(TransportClient client) throws RuntimeException;
+   void doBootstrap(TransportClient client, Channel channel) throws RuntimeException;
}
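
With the Channel now passed in, a client bootstrap can modify the pipeline
itself, which is exactly what the SASL encryption bootstrap needs. A trivial,
hypothetical implementation of the new signature:

    import io.netty.channel.Channel;

    import org.apache.spark.network.client.TransportClient;
    import org.apache.spark.network.client.TransportClientBootstrap;

    class LoggingClientBootstrap implements TransportClientBootstrap {
      @Override
      public void doBootstrap(TransportClient client, Channel channel) {
        // The channel reference lets a bootstrap insert pipeline handlers
        // (e.g. encryption), which the old client-only signature could not.
        System.out.println("Bootstrapping channel to " + channel.remoteAddress());
      }
    }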
network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -172,12 +172,14 @@ private TransportClient createClient(InetSocketAddress address) throws IOExcepti
.option(ChannelOption.ALLOCATOR, pooledAllocator);

final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
+   final AtomicReference<Channel> channelRef = new AtomicReference<Channel>();

bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
TransportChannelHandler clientHandler = context.initializePipeline(ch);
clientRef.set(clientHandler.getClient());
+         channelRef.set(ch);
}
});

@@ -192,14 +194,15 @@ public void initChannel(SocketChannel ch) {
}

TransportClient client = clientRef.get();
+   Channel channel = channelRef.get();
assert client != null : "Channel future completed successfully with null client";

// Execute any client bootstraps synchronously before marking the Client as successful.
long preBootstrap = System.nanoTime();
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
-         clientBootstrap.doBootstrap(client);
+         clientBootstrap.doBootstrap(client, channel);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java
@@ -17,8 +17,12 @@

package org.apache.spark.network.sasl;

+ import javax.security.sasl.Sasl;
+ import javax.security.sasl.SaslException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
+ import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -33,14 +37,24 @@
public class SaslClientBootstrap implements TransportClientBootstrap {
private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class);

+   private final boolean encrypt;
private final TransportConf conf;
private final String appId;
private final SecretKeyHolder secretKeyHolder;

public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+     this(conf, appId, secretKeyHolder, false);
+   }
+
+   public SaslClientBootstrap(
+       TransportConf conf,
+       String appId,
+       SecretKeyHolder secretKeyHolder,
+       boolean encrypt) {
this.conf = conf;
this.appId = appId;
this.secretKeyHolder = secretKeyHolder;
+     this.encrypt = encrypt;
}

/**
@@ -49,8 +63,8 @@ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec
* due to mismatch.
*/
@Override
-   public void doBootstrap(TransportClient client) {
-     SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
+   public void doBootstrap(TransportClient client, Channel channel) {
+     SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
try {
byte[] payload = saslClient.firstToken();

@@ -62,13 +76,26 @@ public void doBootstrap(TransportClient client) {
byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs());
payload = saslClient.response(response);
}

+       if (encrypt) {
+         if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) {
+           throw new RuntimeException(
+             new SaslException("Encryption requests by negotiated non-encrypted connection."));
+         }
+         SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
+         saslClient = null;
+         logger.debug("Channel {} configured for SASL encryption.", client);
+       }
} finally {
-       try {
-         // Once authentication is complete, the server will trust all remaining communication.
-         saslClient.dispose();
-       } catch (RuntimeException e) {
-         logger.error("Error while disposing SASL client", e);
+       if (saslClient != null) {
+         try {
+           // Once authentication is complete, the server will trust all remaining communication.
+           saslClient.dispose();
+         } catch (RuntimeException e) {
+           logger.error("Error while disposing SASL client", e);
+         }
+       }
}
}

}
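
Whether the connection ends up encrypted is decided by standard JDK SASL QOP
negotiation; the QOP_AUTH_CONF constant checked above is the standard
"auth-conf" value. A sketch of those mechanics, with mechanism, protocol, and
server name as placeholders:

    import java.util.HashMap;
    import java.util.Map;
    import javax.security.auth.callback.CallbackHandler;
    import javax.security.sasl.Sasl;
    import javax.security.sasl.SaslClient;

    class QopNegotiationSketch {
      static SaslClient newClient(CallbackHandler handler) throws Exception {
        Map<String, String> props = new HashMap<String, String>();
        // Request confidentiality but allow fallback to auth-only; the
        // bootstrap above instead fails hard when auth-conf is not granted.
        props.put(Sasl.QOP, "auth-conf,auth");
        return Sasl.createSaslClient(new String[] { "DIGEST-MD5" },
            null, "spark", "default", props, handler);
      }

      static boolean negotiatedEncryption(SaslClient client) {
        // Valid only after the handshake completes; "auth-conf" means
        // wrap()/unwrap() will encrypt and decrypt payloads.
        return "auth-conf".equals(client.getNegotiatedProperty(Sasl.QOP));
      }
    }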
(Diff truncated; the remaining changed files are not shown.)
