Skip to content

Commit

Permalink
Implement fallback to IPv4/IPv6 vice-versa for sync
Browse files Browse the repository at this point in the history
This commit changes the way the driver attempts connections to the
ip addresses that are associated with a host. Instead of only fetching
the first ip address and attempting a connection that single ip address,
the driver will now fetch all of the ip addresses and attempt to
connect to each one until a successful connection is made, or until
it is determined that none of the ip addresses can be connected to.
Note that this commit is only for sync. The implementation of this
for async will be handled in a separate commit. A couple of tests
were added for this change as well.

JAVA-2700
  • Loading branch information
thejonathanma committed Sep 27, 2018
1 parent b75842d commit 146c465
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 31 deletions.
23 changes: 23 additions & 0 deletions driver-core/src/main/com/mongodb/ServerAddress.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;

/**
* Represents the location of a Mongo server - i.e. server name and port number
Expand Down Expand Up @@ -189,6 +191,27 @@ public InetSocketAddress getSocketAddress() {
}
}

/**
* Gets all underlying socket addresses
*
* @return array of socket addresses
*
* @since 3.9
*/
public List<InetSocketAddress> getSocketAddresses() {
try {
InetAddress[] inetAddresses = InetAddress.getAllByName(host);
List<InetSocketAddress> inetSocketAddressList = new ArrayList<InetSocketAddress>();
for (InetAddress inetAddress : inetAddresses) {
inetSocketAddressList.add(new InetSocketAddress(inetAddress, port));
}

return inetSocketAddressList;
} catch (UnknownHostException e) {
throw new MongoSocketException(e.getMessage(), this, e);
}
}

@Override
public String toString() {
return host + ":" + port;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public void initChannel(final SocketChannel ch) throws Exception {
SSLEngine engine = getSslContext().createSSLEngine(address.getHost(), address.getPort());
engine.setUseClientMode(true);
SSLParameters sslParameters = engine.getSSLParameters();
enableSni(address, sslParameters);
enableSni(address.getHost(), sslParameters);
if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package com.mongodb.internal.connection;

import com.mongodb.ServerAddress;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLParameters;
Expand All @@ -37,9 +35,9 @@ final class Java8SniSslHelper implements SniSslHelper {
}

@Override
public void enableSni(final ServerAddress address, final SSLParameters sslParameters) {
public void enableSni(final String host, final SSLParameters sslParameters) {
try {
SNIServerName sniHostName = new SNIHostName(address.getHost());
SNIServerName sniHostName = new SNIHostName(host);
sslParameters.setServerNames(singletonList(sniHostName));
} catch (IllegalArgumentException e) {
// ignore because SNIHostName will throw this for some legit host names for connecting to MongoDB, e.g an IPV6 literal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@

package com.mongodb.internal.connection;

import com.mongodb.ServerAddress;

import javax.net.ssl.SSLParameters;

interface SniSslHelper {

/**
* Enable SNI.
*
* @param address the server address
* @param host the server host
* @param sslParameters the SSL parameters
*/
void enableSni(ServerAddress address, SSLParameters sslParameters);
void enableSni(String host, SSLParameters sslParameters);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.MongoSocketReadException;
import com.mongodb.ServerAddress;
Expand All @@ -27,8 +28,11 @@
import org.bson.ByteBuf;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.List;

import static com.mongodb.assertions.Assertions.isTrue;
Expand All @@ -51,16 +55,32 @@ public SocketChannelStream(final ServerAddress address, final SocketSettings set
}

@Override
public void open() throws IOException {
public void open() {
try {
socketChannel = SocketChannel.open();
SocketStreamHelper.initialize(socketChannel.socket(), address, settings, sslSettings);
socketChannel = initializeSocketChannel();
} catch (IOException e) {
close();
throw new MongoSocketOpenException("Exception opening socket", getAddress(), e);
}
}

private SocketChannel initializeSocketChannel() throws IOException {
Iterator<InetSocketAddress> inetSocketAddresses = address.getSocketAddresses().iterator();
while (inetSocketAddresses.hasNext()) {
SocketChannel socketChannel = SocketChannel.open();
try {
SocketStreamHelper.initialize(socketChannel.socket(), inetSocketAddresses.next(), settings, sslSettings);
return socketChannel;
} catch (SocketTimeoutException e) {
if (!inetSocketAddresses.hasNext()) {
throw e;
}
}
}

throw new MongoSocketException("Exception opening socket", getAddress());
}

@Override
public ByteBuf getBuffer(final int size) {
return bufferProvider.getBuffer(size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.MongoSocketReadException;
import com.mongodb.ServerAddress;
Expand All @@ -30,7 +31,10 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.Iterator;
import java.util.List;

import static com.mongodb.assertions.Assertions.notNull;
Expand All @@ -56,10 +60,9 @@ public SocketStream(final ServerAddress address, final SocketSettings settings,
}

@Override
public void open() throws IOException {
public void open() {
try {
socket = socketFactory.createSocket();
SocketStreamHelper.initialize(socket, address, settings, sslSettings);
socket = initializeSocket();
outputStream = socket.getOutputStream();
inputStream = socket.getInputStream();
} catch (IOException e) {
Expand All @@ -68,6 +71,23 @@ public void open() throws IOException {
}
}

private Socket initializeSocket() throws IOException {
Iterator<InetSocketAddress> inetSocketAddresses = address.getSocketAddresses().iterator();
while (inetSocketAddresses.hasNext()) {
Socket socket = socketFactory.createSocket();
try {
SocketStreamHelper.initialize(socket, inetSocketAddresses.next(), settings, sslSettings);
return socket;
} catch (SocketTimeoutException e) {
if (!inetSocketAddresses.hasNext()) {
throw e;
}
}
}

throw new MongoSocketException("Exception opening socket", getAddress());
}

@Override
public ByteBuf getBuffer(final int size) {
return bufferProvider.getBuffer(size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
package com.mongodb.internal.connection;

import com.mongodb.MongoInternalException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;

import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSocket;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;

import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
Expand All @@ -33,8 +33,8 @@
final class SocketStreamHelper {

@SuppressWarnings("deprecation")
static void initialize(final Socket socket, final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings)
throws IOException {
static void initialize(final Socket socket, final InetSocketAddress inetSocketAddress, final SocketSettings settings,
final SslSettings sslSettings) throws IOException {
socket.setTcpNoDelay(true);
socket.setSoTimeout(settings.getReadTimeout(MILLISECONDS));
socket.setKeepAlive(settings.isKeepAlive());
Expand All @@ -54,14 +54,14 @@ static void initialize(final Socket socket, final ServerAddress address, final S
sslParameters = new SSLParameters();
}

enableSni(address, sslParameters);
enableSni(inetSocketAddress.getHostName(), sslParameters);

if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
sslSocket.setSSLParameters(sslParameters);
}
socket.connect(address.getSocketAddress(), settings.getConnectTimeout(MILLISECONDS));
socket.connect(inetSocketAddress, settings.getConnectTimeout(MILLISECONDS));
}

private SocketStreamHelper() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

package com.mongodb.internal.connection;

import com.mongodb.ServerAddress;

import javax.net.ssl.SSLParameters;
import java.lang.reflect.InvocationTargetException;

Expand Down Expand Up @@ -69,12 +67,12 @@ public static void enableHostNameVerification(final SSLParameters sslParameters)
/**
* Enable SNI if running on Java 8 or later. Otherwise fail silently to enable SNI.
*
* @param address the server address
* @param host the server host
* @param sslParameters the SSL parameters
*/
public static void enableSni(final ServerAddress address, final SSLParameters sslParameters) {
public static void enableSni(final String host, final SSLParameters sslParameters) {
if (SNI_SSL_HELPER != null) {
SNI_SSL_HELPER.enableSni(address, sslParameters);
SNI_SSL_HELPER.enableSni(host, sslParameters);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SocketStreamHelperSpecification extends Specification {
.build()
when:
SocketStreamHelper.initialize(socket, getPrimary(), socketSettings, SslSettings.builder().build())
SocketStreamHelper.initialize(socket, getPrimary().getSocketAddress(), socketSettings, SslSettings.builder().build())
then:
socket.getTcpNoDelay()
Expand All @@ -61,7 +61,8 @@ class SocketStreamHelperSpecification extends Specification {
Socket socket = SocketFactory.default.createSocket()
when:
SocketStreamHelper.initialize(socket, getPrimary(), SocketSettings.builder().build(), SslSettings.builder().build())
SocketStreamHelper.initialize(socket, getPrimary().getSocketAddress(),
SocketSettings.builder().build(), SslSettings.builder().build())
then:
socket.isConnected()
Expand All @@ -76,7 +77,7 @@ class SocketStreamHelperSpecification extends Specification {
SSLSocket socket = SSLSocketFactory.default.createSocket()
when:
SocketStreamHelper.initialize(socket, getPrimary(), SocketSettings.builder().build(), sslSettings)
SocketStreamHelper.initialize(socket, getPrimary().getSocketAddress(), SocketSettings.builder().build(), sslSettings)
then:
socket.getSSLParameters().endpointIdentificationAlgorithm == (sslSettings.invalidHostNameAllowed ? null : 'HTTPS')
Expand All @@ -96,7 +97,7 @@ class SocketStreamHelperSpecification extends Specification {
SSLSocket socket = SSLSocketFactory.default.createSocket()
when:
SocketStreamHelper.initialize(socket, getPrimary(), SocketSettings.builder().build(), sslSettings)
SocketStreamHelper.initialize(socket, getPrimary().getSocketAddress(), SocketSettings.builder().build(), sslSettings)
then:
socket.getSSLParameters().getServerNames() == [new SNIHostName(getPrimary().getHost())]
Expand All @@ -114,7 +115,8 @@ class SocketStreamHelperSpecification extends Specification {
Socket socket = SocketFactory.default.createSocket()
when:
SocketStreamHelper.initialize(socket, getPrimary(), SocketSettings.builder().build(), SslSettings.builder().enabled(true).build())
SocketStreamHelper.initialize(socket, getPrimary().getSocketAddress(), SocketSettings.builder().build(),
SslSettings.builder().enabled(true).build())
then:
thrown(MongoInternalException)
Expand Down

0 comments on commit 146c465

Please sign in to comment.