Skip to content

Commit

Permalink
Fix Bolt handshake write handling and timeout management (#1528) (#1546)
Browse files Browse the repository at this point in the history
  • Loading branch information
injectives authored Apr 2, 2024
1 parent 2a20e31 commit d56af61
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
package org.neo4j.driver.internal.async.connection;

import static java.lang.String.format;
import static org.neo4j.driver.internal.async.connection.BoltProtocolUtil.handshakeBuf;
import static org.neo4j.driver.internal.async.connection.BoltProtocolUtil.handshakeString;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import javax.net.ssl.SSLHandshakeException;
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.exceptions.SecurityException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.logging.ChannelActivityLogger;
Expand Down Expand Up @@ -61,7 +62,18 @@ public void operationComplete(ChannelFuture future) {
ChannelPipeline pipeline = channel.pipeline();
pipeline.addLast(new HandshakeHandler(pipelineBuilder, handshakeCompletedPromise, logging));
log.debug("C: [Bolt Handshake] %s", handshakeString());
channel.writeAndFlush(handshakeBuf(), channel.voidPromise());
channel.writeAndFlush(BoltProtocolUtil.handshakeBuf()).addListener(f -> {
if (!f.isSuccess()) {
Throwable error = f.cause();
if (error instanceof SSLHandshakeException) {
error = new SecurityException("Failed to establish secured connection with the server", error);
} else {
error = new ServiceUnavailableException(
String.format("Unable to write Bolt handshake to %s.", this.address), error);
}
this.handshakeCompletedPromise.setFailure(error);
}
});
} else {
handshakeCompletedPromise.setFailure(databaseUnavailableError(address, future.cause()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ private void installHandshakeCompletedListeners(

// remove timeout handler from the pipeline once TLS and Bolt handshakes are completed. regular protocol
// messages will flow next and we do not want to have read timeout for them
handshakeCompleted.addListener(future -> pipeline.remove(ConnectTimeoutHandler.class));
handshakeCompleted.addListener(future -> {
if (future.isSuccess()) {
pipeline.remove(ConnectTimeoutHandler.class);
}
});

// add listener that sends an INIT message. connection is now fully established. channel pipeline if fully
// set to send/receive messages for a selected protocol version
Expand Down
2 changes: 2 additions & 0 deletions driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.internal.BoltServerAddress;
Expand Down Expand Up @@ -152,6 +153,7 @@ void shouldFailToCreateUnencryptedDriverWhenServerDoesNotRespond() throws IOExce
}

@Test
@Disabled("TLS actually fails, the test setup is not valid")
void shouldFailToCreateEncryptedDriverWhenServerDoesNotRespond() throws IOException {
testFailureWhenServerDoesNotRespond(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.neo4j.driver.AuthToken;
Expand Down Expand Up @@ -158,6 +159,7 @@ void shouldFailWhenProtocolNegotiationTakesTooLong() throws Exception {
}

@Test
@Disabled("TLS actually fails, the test setup is not valid")
void shouldFailWhenTLSHandshakeTakesTooLong() throws Exception {
// run with TLS so that TLS handshake is the very first operation after connection is established
testReadTimeoutOnConnect(trustAllCertificates());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void shouldOperateWithEncryptionWhenItIsOptionalInTheDatabase() {

@Test
void shouldFailWithoutEncryptionWhenItIsRequiredInTheDatabase() {
testMismatchingEncryption(BoltTlsLevel.REQUIRED, false);
testMismatchingEncryption(BoltTlsLevel.REQUIRED, false, "Connection to the database terminated");
}

@Test
Expand All @@ -74,7 +74,7 @@ void shouldOperateWithEncryptionWhenConfiguredUsingBoltSscURI() {

@Test
void shouldFailWithEncryptionWhenItIsDisabledInTheDatabase() {
testMismatchingEncryption(BoltTlsLevel.DISABLED, true);
testMismatchingEncryption(BoltTlsLevel.DISABLED, true, "Unable to write Bolt handshake to");
}

@Test
Expand Down Expand Up @@ -110,7 +110,7 @@ private void testMatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypt
}
}

private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypted) {
private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypted, String errorMessage) {
Map<String, String> tlsConfig = new HashMap<>();
tlsConfig.put(Neo4jSettings.BOLT_TLS_LEVEL, tlsLevel.toString());
neo4j.deleteAndStartNeo4j(tlsConfig);
Expand All @@ -120,7 +120,7 @@ private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncr
ServiceUnavailableException.class, () -> GraphDatabase.driver(neo4j.uri(), neo4j.authToken(), config)
.verifyConnectivity());

assertThat(e.getMessage(), startsWith("Connection to the database terminated"));
assertThat(e.getMessage(), startsWith(errorMessage));
}

private static Config newConfig(boolean withEncryption) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package org.neo4j.driver.internal.async.connection;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand All @@ -29,7 +31,9 @@

import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.concurrent.Future;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
Expand Down Expand Up @@ -73,6 +77,25 @@ void shouldWriteHandshakeWhenChannelConnected() {
assertEquals(handshakeBuf(), channel.readOutbound());
}

@Test
void shouldCompleteHandshakePromiseExceptionallyOnWriteFailure() {
ChannelPromise handshakeCompletedPromise = channel.newPromise();
ChannelConnectedListener listener = newListener(handshakeCompletedPromise);
ChannelPromise channelConnectedPromise = channel.newPromise();
channelConnectedPromise.setSuccess();
channel.close();

listener.operationComplete(channelConnectedPromise);

assertTrue(handshakeCompletedPromise.isDone());
CompletableFuture<Future<?>> future = new CompletableFuture<>();
handshakeCompletedPromise.addListener(future::complete);
Future<?> handshakeFuture = future.join();
assertTrue(handshakeFuture.isDone());
assertFalse(handshakeFuture.isSuccess());
assertInstanceOf(ServiceUnavailableException.class, handshakeFuture.cause());
}

private static ChannelConnectedListener newListener(ChannelPromise handshakeCompletedPromise) {
return new ChannelConnectedListener(
LOCAL_DEFAULT, new ChannelPipelineBuilderImpl(), handshakeCompletedPromise, DEV_NULL_LOGGING);
Expand Down

0 comments on commit d56af61

Please sign in to comment.