Skip to content

Commit

Permalink
ISPN-15069 More flexbile SNI from the client
Browse files Browse the repository at this point in the history
* Per-cluster SNI
* Implict SNI if not set
* Also add a method to return current cluster
  • Loading branch information
tristantarrant committed Jul 21, 2023
1 parent 90c6140 commit 317c584
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package org.infinispan.client.hotrod;

import jakarta.transaction.TransactionManager;

import org.infinispan.client.hotrod.configuration.Configuration;
import org.infinispan.client.hotrod.configuration.TransactionMode;
import org.infinispan.commons.api.BasicCacheContainer;
import org.infinispan.commons.marshall.Marshaller;

import jakarta.transaction.TransactionManager;

public interface RemoteCacheContainer extends BasicCacheContainer {

/**
Expand Down Expand Up @@ -152,6 +152,13 @@ <K, V> RemoteCache<K, V> getCache(String cacheName, boolean forceReturnValue, Tr
*/
boolean switchToDefaultCluster();

/**
* Returns the name of the currently active cluster.
*
* @return the name of the active cluster
*/
String getCurrentClusterName();

Marshaller getMarshaller();

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,11 @@ public boolean switchToDefaultCluster() {
return channelFactory.manualSwitchToCluster(ChannelFactory.DEFAULT_CLUSTER_NAME);
}

@Override
public String getCurrentClusterName() {
return channelFactory.getCurrentClusterName();
}

private Properties loadFromStream(InputStream stream) {
Properties properties = new Properties();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ public class ClusterConfiguration {
private final List<ServerConfiguration> serverCluster;
private final String clusterName;
private final ClientIntelligence intelligence;
private final String sniHostName;

public ClusterConfiguration(List<ServerConfiguration> serverCluster, String clusterName, ClientIntelligence intelligence) {
public ClusterConfiguration(List<ServerConfiguration> serverCluster, String clusterName, ClientIntelligence intelligence, String sniHostName) {
this.serverCluster = serverCluster;
this.clusterName = clusterName;
this.intelligence = intelligence;
this.sniHostName = sniHostName;
}

public List<ServerConfiguration> getCluster() {
Expand All @@ -30,12 +32,17 @@ public ClientIntelligence getClientIntelligence() {
return intelligence;
}

public String getSniHostName() {
return sniHostName;
}

@Override
public String toString() {
return "ClusterConfiguration{" +
"serverCluster=" + Util.toStr(serverCluster) +
", clusterName='" + clusterName + '\'' +
", intelligence=" + intelligence +
", sniHostName=" + sniHostName +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class ClusterConfigurationBuilder extends AbstractConfigurationChildBuild
private final List<ServerConfigurationBuilder> servers = new ArrayList<>();
private final String clusterName;
private ClientIntelligence intelligence;
private String sniHostName;

protected ClusterConfigurationBuilder(ConfigurationBuilder builder, String clusterName) {
super(builder);
Expand Down Expand Up @@ -53,6 +54,11 @@ public ClusterConfigurationBuilder clusterClientIntelligence(ClientIntelligence
return this;
}

public ClusterConfigurationBuilder clusterSniHostName(String clusterSniHostName) {
this.sniHostName = clusterSniHostName;
return this;
}

@Override
public void validate() {
if (clusterName == null || clusterName.isEmpty()) {
Expand All @@ -70,7 +76,7 @@ public void validate() {
public ClusterConfiguration create() {
List<ServerConfiguration> serverCluster = servers.stream()
.map(ServerConfigurationBuilder::create).collect(Collectors.toList());
return new ClusterConfiguration(serverCluster, clusterName, intelligence);
return new ClusterConfiguration(serverCluster, clusterName, intelligence, sniHostName);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,22 @@ public class ClusterInfo {
// updates won't be allowed to apply since they refer to older views.
private final int topologyAge;
private final ClientIntelligence intelligence;
private final String sniHostName;

public ClusterInfo(String clusterName, List<InetSocketAddress> servers, ClientIntelligence intelligence) {
this(clusterName, servers, -1, intelligence);
public ClusterInfo(String clusterName, List<InetSocketAddress> servers, ClientIntelligence intelligence, String sniHostName) {
this(clusterName, servers, -1, intelligence, sniHostName);
}

private ClusterInfo(String clusterName, List<InetSocketAddress> servers, int topologyAge, ClientIntelligence intelligence) {
private ClusterInfo(String clusterName, List<InetSocketAddress> servers, int topologyAge, ClientIntelligence intelligence, String sniHostName) {
this.clusterName = clusterName;
this.servers = Immutables.immutableListCopy(servers);
this.topologyAge = topologyAge;
this.intelligence = Objects.requireNonNull(intelligence);
this.sniHostName = sniHostName;
}

public ClusterInfo withTopologyAge(int topologyAge) {
return new ClusterInfo(clusterName, servers, topologyAge, intelligence);
return new ClusterInfo(clusterName, servers, topologyAge, intelligence, sniHostName);
}

public String getName() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ public class ChannelFactory {
@GuardedBy("lock")
private final Set<SocketAddress> failedServers = new HashSet<>();
private final CodecHolder codecHolder;
private AddressResolverGroup<?> dnsResolver;

public ChannelFactory(CodecHolder codecHolder) {
this.codecHolder = codecHolder;
Expand All @@ -123,12 +124,17 @@ public void start(Configuration configuration, Marshaller marshaller, ExecutorSe
// Note that each event loop opens a selector which counts
int maxExecutors = Math.min(asyncThreads, eventLoopThreads);
this.eventLoopGroup = configuration.transportFactory().createEventLoopGroup(maxExecutors, executorService);
DnsNameResolverBuilder builder = new DnsNameResolverBuilder()
.channelType(configuration.transportFactory().datagramChannelClass())
.ttl(configuration.dnsResolverMinTTL(), configuration.dnsResolverMaxTTL())
.negativeTtl(configuration.dnsResolverNegativeTTL());
this.dnsResolver = new RoundRobinDnsAddressResolverGroup(builder);

List<InetSocketAddress> initialServers = new ArrayList<>();
for (ServerConfiguration server : configuration.servers()) {
initialServers.add(InetSocketAddress.createUnresolved(server.host(), server.port()));
}
ClusterInfo mainCluster = new ClusterInfo(DEFAULT_CLUSTER_NAME, initialServers, configuration.clientIntelligence());
ClusterInfo mainCluster = new ClusterInfo(DEFAULT_CLUSTER_NAME, initialServers, configuration.clientIntelligence(), configuration.security().ssl().sniHostName());
List<ClusterInfo> clustersDefinitions = new ArrayList<>();
if (log.isDebugEnabled()) {
log.debugf("Statically configured servers: %s", initialServers);
Expand All @@ -145,8 +151,10 @@ public void start(Configuration configuration, Marshaller marshaller, ExecutorSe
ClientIntelligence intelligence = clusterConfiguration.getClientIntelligence() != null ?
clusterConfiguration.getClientIntelligence() :
configuration.clientIntelligence();

String sniHostName = clusterConfiguration.getSniHostName() != null ? clusterConfiguration.getSniHostName() : configuration.security().ssl().sniHostName();
ClusterInfo alternateCluster =
new ClusterInfo(clusterConfiguration.getClusterName(), alternateServers, intelligence);
new ClusterInfo(clusterConfiguration.getClusterName(), alternateServers, intelligence, sniHostName);
log.debugf("Add secondary cluster: %s", alternateCluster);
clustersDefinitions.add(alternateCluster);
}
Expand Down Expand Up @@ -187,12 +195,7 @@ public MarshallerRegistry getMarshallerRegistry() {
}

private ChannelPool newPool(SocketAddress address) {
log.debugf("Creating new channel pool for %s", address);
DnsNameResolverBuilder builder = new DnsNameResolverBuilder()
.channelType(configuration.transportFactory().datagramChannelClass())
.ttl(configuration.dnsResolverMinTTL(), configuration.dnsResolverMaxTTL())
.negativeTtl(configuration.dnsResolverNegativeTTL());
AddressResolverGroup<?> dnsResolver = new RoundRobinDnsAddressResolverGroup(builder);
log.debugf("Creating new channel pool for %s", address);
Bootstrap bootstrap = new Bootstrap()
.group(eventLoopGroup)
.channel(configuration.transportFactory().socketChannelClass())
Expand All @@ -210,7 +213,7 @@ private ChannelPool newPool(SocketAddress address) {
}

public ChannelInitializer createChannelInitializer(SocketAddress address, Bootstrap bootstrap) {
return new ChannelInitializer(bootstrap, address, operationsFactory, configuration, this);
return new ChannelInitializer(bootstrap, address, operationsFactory, configuration, this, findCluster(getCurrentClusterName()));
}

protected ChannelPool createChannelPool(Bootstrap bootstrap, ChannelInitializer channelInitializer, SocketAddress address) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.infinispan.client.hotrod.impl.transport.netty;

import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.Principal;
import java.security.PrivilegedActionException;
Expand All @@ -27,6 +28,7 @@
import org.infinispan.client.hotrod.configuration.Configuration;
import org.infinispan.client.hotrod.configuration.SslConfiguration;
import org.infinispan.client.hotrod.impl.operations.OperationsFactory;
import org.infinispan.client.hotrod.impl.topology.ClusterInfo;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.CacheConfigurationException;
Expand All @@ -53,6 +55,7 @@ class ChannelInitializer extends io.netty.channel.ChannelInitializer<Channel> {
private final OperationsFactory operationsFactory;
private final Configuration configuration;
private final ChannelFactory channelFactory;
private final ClusterInfo cluster;
private ChannelPool channelPool;
private volatile boolean isFirstPing = true;

Expand All @@ -76,12 +79,13 @@ class ChannelInitializer extends io.netty.channel.ChannelInitializer<Channel> {
SECURITY_PROVIDERS = providers.toArray(new Provider[0]);
}

ChannelInitializer(Bootstrap bootstrap, SocketAddress unresolvedAddress, OperationsFactory operationsFactory, Configuration configuration, ChannelFactory channelFactory) {
ChannelInitializer(Bootstrap bootstrap, SocketAddress unresolvedAddress, OperationsFactory operationsFactory, Configuration configuration, ChannelFactory channelFactory, ClusterInfo cluster) {
this.bootstrap = bootstrap;
this.unresolvedAddress = unresolvedAddress;
this.operationsFactory = operationsFactory;
this.configuration = configuration;
this.channelFactory = channelFactory;
this.cluster = cluster;
}

CompletableFuture<Channel> createChannel() {
Expand Down Expand Up @@ -175,14 +179,21 @@ private void initSsl(Channel channel) {
} else {
sslContext = new JdkSslContext(ssl.sslContext(), true, ClientAuth.NONE);
}

SslHandler sslHandler = sslContext.newHandler(channel.alloc(), ssl.sniHostName(), -1);
if (ssl.sniHostName() != null) {
SSLParameters sslParameters = sslHandler.engine().getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(new SNIHostName(ssl.sniHostName())));
sslHandler.engine().setSSLParameters(sslParameters);
String sniHostName;
if (cluster != null && cluster.getName() != null) {
sniHostName = cluster.getName();
} else if (ssl.sniHostName() != null) {
sniHostName = ssl.sniHostName();
} else {
sniHostName = ((InetSocketAddress) unresolvedAddress).getHostString();
}
channel.pipeline().addFirst(sslHandler,
SslHandshakeExceptionHandler.INSTANCE);
SSLParameters sslParameters = sslHandler.engine().getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(new SNIHostName(sniHostName)));
sslHandler.engine().setSSLParameters(sslParameters);

channel.pipeline().addFirst(sslHandler, SslHandshakeExceptionHandler.INSTANCE);
}

private void initAuthentication(Channel channel, AuthenticationConfiguration authentication) throws PrivilegedActionException, SaslException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class RoundRobinBalancingStrategy implements FailoverRequestBalancingStra

@Override
public void setServers(Collection<SocketAddress> servers) {
this.servers = servers.toArray(new SocketAddress[servers.size()]);
this.servers = servers.toArray(new SocketAddress[0]);
// Always start with a random server after a topology update
index = ThreadLocalRandom.current().nextInt(this.servers.length);
if (log.isTraceEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public TestChannelFactory(CodecHolder codecHolder) {

@Override
public ChannelInitializer createChannelInitializer(SocketAddress address, Bootstrap bootstrap) {
return new ChannelInitializer(bootstrap, address, getOperationsFactory(), getConfiguration(), this) {
return new ChannelInitializer(bootstrap, address, getOperationsFactory(), getConfiguration(), this, null) {
@Override
protected void initChannel(Channel channel) throws Exception {
super.initChannel(channel);
Expand Down

0 comments on commit 317c584

Please sign in to comment.