Skip to content

Commit

Permalink
Split driver context into public/internal interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
olim7t committed Apr 5, 2017
1 parent 194d6f8 commit 1e7d46e
Show file tree
Hide file tree
Showing 18 changed files with 420 additions and 170 deletions.
Expand Up @@ -25,19 +25,6 @@
*/
public interface AuthProvider {

/**
* A provider that provides no authentication capability.
*
* <p>This is only useful as a placeholder when no authentication is to be used.
*/
AuthProvider NONE =
(host, serverAuthenticator) -> {
throw new AuthenticationException(
host,
String.format(
"Host %s requires authentication, but no authenticator configured", host));
};

/**
* The authenticator to use when connecting to {@code host}.
*
Expand Down
Expand Up @@ -17,6 +17,7 @@

import com.datastax.oss.driver.api.core.config.CoreDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.google.common.base.Charsets;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -46,9 +47,9 @@ public class PlainTextAuthProvider implements AuthProvider {

private final DriverConfigProfile config;

/** Builds a new instance from the driver configuration. */
public PlainTextAuthProvider(DriverConfigProfile config) {
this.config = config;
/** Builds a new instance. */
public PlainTextAuthProvider(DriverContext context) {
this.config = context.config().defaultProfile();
}

@Override
Expand Down
@@ -0,0 +1,34 @@
/*
* Copyright (C) 2017-2017 DataStax Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.context;

import com.datastax.oss.driver.api.core.auth.AuthProvider;
import com.datastax.oss.driver.api.core.config.DriverConfig;
import com.datastax.oss.driver.api.core.ssl.SslEngineFactory;
import java.util.Optional;

/** Holds common components that are shared throughout a driver instance. */
public interface DriverContext {

/** The driver's configuration. */
DriverConfig config();

/** The authentication provider, if authentication was configured. */
Optional<AuthProvider> authProvider();

/** The SSL engine factory, if SSL was configured. */
Optional<SslEngineFactory> sslEngineFactory();
}
Expand Up @@ -17,6 +17,7 @@

import com.datastax.oss.driver.api.core.config.CoreDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigProfile;
import com.datastax.oss.driver.api.core.context.DriverContext;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
Expand Down Expand Up @@ -45,16 +46,17 @@
*/
public class DefaultSslEngineFactory implements SslEngineFactory {

private final SSLContext context;
private final SSLContext sslContext;
private final String[] cipherSuites;

/** Builds a new instance from the driver configuration. */
public DefaultSslEngineFactory(DriverConfigProfile config) {
public DefaultSslEngineFactory(DriverContext driverContext) {
try {
this.context = SSLContext.getDefault();
this.sslContext = SSLContext.getDefault();
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Cannot initialize SSL Context", e);
}
DriverConfigProfile config = driverContext.config().defaultProfile();
if (config.isDefined(CoreDriverOption.SSL_CONFIG_CIPHER_SUITES)) {
List<String> list = config.getStringList(CoreDriverOption.SSL_CONFIG_CIPHER_SUITES);
String tmp[] = new String[list.size()];
Expand All @@ -69,9 +71,9 @@ public SSLEngine newSslEngine(SocketAddress remoteEndpoint) {
SSLEngine engine;
if (remoteEndpoint instanceof InetSocketAddress) {
InetSocketAddress address = (InetSocketAddress) remoteEndpoint;
engine = context.createSSLEngine(address.getHostName(), address.getPort());
engine = sslContext.createSSLEngine(address.getHostName(), address.getPort());
} else {
engine = context.createSSLEngine();
engine = sslContext.createSSLEngine();
}
engine.setUseClientMode(true);
if (cipherSuites != null) {
Expand Down
Expand Up @@ -20,8 +20,8 @@
import com.datastax.oss.driver.api.core.UnsupportedProtocolVersionException;
import com.datastax.oss.driver.api.core.config.CoreDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigProfile;
import com.datastax.oss.driver.internal.core.DriverContext;
import com.datastax.oss.driver.internal.core.NettyOptions;
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.context.NettyOptions;
import com.datastax.oss.driver.internal.core.protocol.FrameDecoder;
import com.datastax.oss.driver.internal.core.protocol.FrameEncoder;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -47,20 +47,20 @@ public class ChannelFactory {

private static final Logger LOG = LoggerFactory.getLogger(ChannelFactory.class);

protected final DriverContext driverContext;
protected final InternalDriverContext internalDriverContext;

/** either set from the configuration, or null and will be negotiated */
@VisibleForTesting ProtocolVersion protocolVersion;

@VisibleForTesting volatile String clusterName;

public ChannelFactory(DriverContext driverContext) {
this.driverContext = driverContext;
public ChannelFactory(InternalDriverContext internalDriverContext) {
this.internalDriverContext = internalDriverContext;

DriverConfigProfile defaultConfig = driverContext.config().defaultProfile();
DriverConfigProfile defaultConfig = internalDriverContext.config().defaultProfile();
if (defaultConfig.isDefined(CoreDriverOption.PROTOCOL_VERSION)) {
String versionName = defaultConfig.getString(CoreDriverOption.PROTOCOL_VERSION);
this.protocolVersion = driverContext.protocolVersionRegistry().fromName(versionName);
this.protocolVersion = internalDriverContext.protocolVersionRegistry().fromName(versionName);
} // else it will be negotiated with the first opened connection
}

Expand All @@ -75,7 +75,7 @@ public CompletionStage<DriverChannel> connect(
currentVersion = protocolVersion;
isNegotiating = false;
} else {
currentVersion = driverContext.protocolVersionRegistry().highestNonBeta();
currentVersion = internalDriverContext.protocolVersionRegistry().highestNonBeta();
isNegotiating = true;
}

Expand All @@ -91,7 +91,7 @@ private void connect(
List<ProtocolVersion> attemptedVersions,
CompletableFuture<DriverChannel> resultFuture) {

NettyOptions nettyOptions = driverContext.nettyOptions();
NettyOptions nettyOptions = internalDriverContext.nettyOptions();

Bootstrap bootstrap =
new Bootstrap()
Expand All @@ -108,7 +108,7 @@ private void connect(
cf -> {
if (connectFuture.isSuccess()) {
DriverChannel driverChannel =
new DriverChannel(connectFuture.channel(), driverContext.writeCoalescer());
new DriverChannel(connectFuture.channel(), internalDriverContext.writeCoalescer());
// If this is the first successful connection, remember the protocol version and
// cluster name for future connections.
if (isNegotiating) {
Expand All @@ -123,7 +123,7 @@ private void connect(
if (error instanceof UnsupportedProtocolVersionException && isNegotiating) {
attemptedVersions.add(currentVersion);
Optional<ProtocolVersion> downgraded =
driverContext.protocolVersionRegistry().downgrade(currentVersion);
internalDriverContext.protocolVersionRegistry().downgrade(currentVersion);
if (downgraded.isPresent()) {
LOG.info(
"Failed to connect with protocol {}, retrying with {}",
Expand All @@ -147,7 +147,7 @@ ChannelInitializer<Channel> initializer(
return new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel channel) throws Exception {
DriverConfigProfile defaultConfigProfile = driverContext.config().defaultProfile();
DriverConfigProfile defaultConfigProfile = internalDriverContext.config().defaultProfile();

long setKeyspaceTimeoutMillis =
defaultConfigProfile.getDuration(
Expand All @@ -164,21 +164,22 @@ protected void initChannel(Channel channel) throws Exception {
new StreamIdGenerator(maxRequestsPerConnection),
setKeyspaceTimeoutMillis);
ProtocolInitHandler initHandler =
new ProtocolInitHandler(driverContext, protocolVersion, clusterName, keyspace);
new ProtocolInitHandler(internalDriverContext, protocolVersion, clusterName, keyspace);

ChannelPipeline pipeline = channel.pipeline();
driverContext
internalDriverContext
.sslHandlerFactory()
.newSslHandler(channel, address)
.map(f -> f.newSslHandler(channel, address))
.map(h -> pipeline.addLast("ssl", h));
pipeline
.addLast("encoder", new FrameEncoder(driverContext.frameCodec()))
.addLast("decoder", new FrameDecoder(driverContext.frameCodec(), maxFrameLength))
.addLast("encoder", new FrameEncoder(internalDriverContext.frameCodec()))
.addLast(
"decoder", new FrameDecoder(internalDriverContext.frameCodec(), maxFrameLength))
.addLast("inflight", inFlightHandler)
.addLast("heartbeat", new HeartbeatHandler(defaultConfigProfile))
.addLast("init", initHandler);

driverContext.nettyOptions().afterChannelInitialized(channel);
internalDriverContext.nettyOptions().afterChannelInitialized(channel);
}
};
}
Expand Down
Expand Up @@ -23,7 +23,7 @@
import com.datastax.oss.driver.api.core.config.CoreDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigProfile;
import com.datastax.oss.driver.api.core.connection.ConnectionException;
import com.datastax.oss.driver.internal.core.DriverContext;
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.util.ProtocolUtils;
import com.datastax.oss.protocol.internal.Message;
import com.datastax.oss.protocol.internal.ProtocolConstants;
Expand All @@ -40,6 +40,7 @@
import com.datastax.oss.protocol.internal.util.Bytes;
import com.google.common.base.Charsets;
import io.netty.channel.ChannelHandlerContext;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.TimeUnit;
Expand All @@ -55,22 +56,22 @@ class ProtocolInitHandler extends ConnectInitHandler {
private static final Query CLUSTER_NAME_QUERY =
new Query("SELECT cluster_name FROM system.local");

private final DriverContext driverContext;
private final InternalDriverContext internalDriverContext;
private final long timeoutMillis;
private final ProtocolVersion initialProtocolVersion;
private final CqlIdentifier keyspaceName;
// might be null if this is the first channel to this cluster
private final String expectedClusterName;

ProtocolInitHandler(
DriverContext driverContext,
InternalDriverContext internalDriverContext,
ProtocolVersion protocolVersion,
String expectedClusterName,
CqlIdentifier keyspaceName) {

this.driverContext = driverContext;
this.internalDriverContext = internalDriverContext;

DriverConfigProfile defaultConfig = driverContext.config().defaultProfile();
DriverConfigProfile defaultConfig = internalDriverContext.config().defaultProfile();

this.timeoutMillis =
defaultConfig.getDuration(
Expand Down Expand Up @@ -138,10 +139,7 @@ void onResponse(Message response) {
send();
} else if (step == Step.STARTUP && response instanceof Authenticate) {
Authenticate authenticate = (Authenticate) response;
authenticator =
driverContext
.authProvider()
.newAuthenticator(channel.remoteAddress(), authenticate.authenticator);
authenticator = buildAuthenticator(channel.remoteAddress(), authenticate.authenticator);
authenticator
.initialResponse()
.whenCompleteAsync(
Expand Down Expand Up @@ -240,6 +238,19 @@ void onResponse(Message response) {
void fail(Throwable cause) {
setConnectFailure(cause);
}

private Authenticator buildAuthenticator(SocketAddress address, String authenticator) {
return internalDriverContext
.authProvider()
.map(p -> p.newAuthenticator(address, authenticator))
.orElseThrow(
() ->
new AuthenticationException(
address,
String.format(
"Host %s requires authentication (%s), but no authenticator configured",
address, authenticator)));
}
}

// TODO we'll probably need a lightweight ResultSet implementation for internal uses, but this is good for now
Expand Down

0 comments on commit 1e7d46e

Please sign in to comment.