diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java index 41aee7d6d9..d643f50000 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java @@ -215,6 +215,7 @@ public static ByteChannel create( String host, int port, Config config, Logger l SocketChannel soChannel = SocketChannel.open(); soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true ); soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true ); + soChannel.connect( new InetSocketAddress( host, port ) ); ByteChannel channel; diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java index 09d1c773df..a370ef4765 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java @@ -21,14 +21,12 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; -import java.nio.channels.SocketChannel; import java.security.GeneralSecurityException; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; -import javax.net.ssl.SSLSession; import org.neo4j.driver.internal.spi.Logger; import org.neo4j.driver.internal.util.BytePrinter; @@ -51,10 +49,9 @@ */ public class TLSSocketChannel implements ByteChannel { - private final SocketChannel channel; // The real channel the data is sent to and read from + private final ByteChannel channel; // The real channel the data is sent to and read from private final Logger logger; - private final SSLContext sslContext; private SSLEngine sslEngine; /** The buffer for network data */ @@ -66,34 +63,36 @@ public class TLSSocketChannel implements ByteChannel private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocateDirect( 0 ); - public TLSSocketChannel( String host, int port, SocketChannel channel, Logger logger, + public TLSSocketChannel( String host, int port, ByteChannel channel, Logger logger, TrustStrategy trustStrategy ) throws GeneralSecurityException, IOException { - logger.debug( "TLS connection enabled" ); - this.logger = logger; - this.channel = channel; - this.channel.configureBlocking( true ); + this(channel, logger, + createSSLEngine( host, port, new SSLContextFactory( host, port, trustStrategy, logger ).create() ) ); - sslContext = new SSLContextFactory( host, port, trustStrategy, logger ).create(); - createSSLEngine( host, port ); - createBuffers(); - runHandshake(); - logger.debug( "TLS connection established" ); } - /** Used in internal tests only */ - TLSSocketChannel( SocketChannel channel, Logger logger, SSLEngine sslEngine, + public TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws GeneralSecurityException, IOException + { + this(channel, logger, sslEngine, + ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ) ); + } + + TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine, ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut ) throws GeneralSecurityException, IOException { - logger.debug( "Testing TLS buffers" ); this.logger = logger; this.channel = channel; - - this.sslContext = SSLContext.getInstance( "TLS" ); this.sslEngine = sslEngine; - resetBuffers( plainIn, cipherIn, plainOut, cipherOut ); // reset buffer size + this.plainIn = plainIn; + this.cipherIn = cipherIn; + this.plainOut = plainOut; + this.cipherOut = cipherOut; + runHandshake(); } /** @@ -126,7 +125,6 @@ private void runHandshake() throws IOException case NEED_UNWRAP: // Unwrap the ssl packet to value ssl handshake information handshakeStatus = unwrap( DUMMY_BUFFER ); - plainIn.clear(); break; case NEED_WRAP: // Wrap the app packet into an ssl packet to add ssl handshake information @@ -134,9 +132,6 @@ private void runHandshake() throws IOException break; } } - - plainIn.clear(); - plainOut.clear(); } private HandshakeStatus runDelegatedTasks() @@ -185,10 +180,11 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException } cipherIn.flip(); - Status status = null; + Status status; do { - status = sslEngine.unwrap( cipherIn, plainIn ).getStatus(); + SSLEngineResult unwrapResult = sslEngine.unwrap( cipherIn, plainIn ); + status = unwrapResult.getStatus(); // Possible status here: // OK - good // BUFFER_OVERFLOW - we need to enlarge* plainIn @@ -244,17 +240,13 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException // Otherwise, make room for reading more data from channel cipherIn.compact(); } - - // I skipped the following check as it "should not" happen at all: - // The channel should not provide us ciphered bytes that cannot hold in the channel buffer at all - // if( cipherIn.remaining() == 0 ) - // {throw new ClientException( "cannot enlarge as it already reached the limit" );} - - // Obtain more inbound network data for cipherIn, - // then retry the operation. return handshakeStatus; // old status + case CLOSED: + // RFC 2246 #7.2.1 requires us to stop accepting input. + sslEngine.closeInbound(); + break; default: - throw new ClientException( "Got unexpected status " + status ); + throw new ClientException( "Got unexpected status " + status + ", " + unwrapResult ); } } while ( cipherIn.hasRemaining() ); /* Remember we are doing blocking reading. @@ -285,7 +277,10 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException case OK: handshakeStatus = runDelegatedTasks(); cipherOut.flip(); - channel.write( cipherOut ); + while(cipherOut.hasRemaining()) + { + channel.write( cipherOut ); + } cipherOut.clear(); break; case BUFFER_OVERFLOW: @@ -344,42 +339,17 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to ) return maxTransfer; } - /** - * Create network buffers and application buffers - * - * @throws IOException - */ - private void createBuffers() throws IOException - { - SSLSession session = sslEngine.getSession(); - int appBufferSize = session.getApplicationBufferSize(); - int netBufferSize = session.getPacketBufferSize(); - - plainOut = ByteBuffer.allocateDirect( appBufferSize ); - plainIn = ByteBuffer.allocateDirect( appBufferSize ); - cipherOut = ByteBuffer.allocateDirect( netBufferSize ); - cipherIn = ByteBuffer.allocateDirect( netBufferSize ); - } - - /** Should only be used in tests */ - void resetBuffers( ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut ) - { - this.plainIn = plainIn; - this.cipherIn = cipherIn; - this.plainOut = plainOut; - this.cipherOut = cipherOut; - } - /** * Create SSLEngine with the SSLContext just created. - * * @param host * @param port + * @param sslContext */ - private void createSSLEngine( String host, int port ) + private static SSLEngine createSSLEngine( String host, int port, SSLContext sslContext ) { - sslEngine = sslContext.createSSLEngine( host, port ); + SSLEngine sslEngine = sslContext.createSSLEngine( host, port ); sslEngine.setUseClientMode( true ); + return sslEngine; } @Override @@ -431,33 +401,46 @@ public boolean isOpen() @Override public void close() throws IOException { - plainOut.clear(); - // Indicate that application is done with engine - sslEngine.closeOutbound(); - - while ( !sslEngine.isOutboundDone() ) + try { - // Get close message - SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut ); - - // Check res statuses + plainOut.clear(); + // Indicate that application is done with engine + sslEngine.closeOutbound(); - // Send close message to peer - cipherOut.flip(); - while ( cipherOut.hasRemaining() ) + while ( !sslEngine.isOutboundDone() ) { - int num = channel.write( cipherOut ); - if ( num == -1 ) + // Get close message + SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut ); + + // Check res statuses + + // Send close message to peer + cipherOut.flip(); + while ( cipherOut.hasRemaining() ) { - // handle closed channel - break; + int num = channel.write( cipherOut ); + if ( num == -1 ) + { + // handle closed channel + break; + } } + cipherOut.clear(); } - cipherOut.clear(); + // Close transport + channel.close(); + logger.debug( "TLS connection closed" ); + } + catch(IOException e) + { + // Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly. + logger.warn( "TLS socket could not be closed cleanly: '"+e.getMessage()+"'", e ); } - // Close transport - channel.close(); - logger.debug( "TLS connection closed" ); } + @Override + public String toString() + { + return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}"; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java b/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java deleted file mode 100644 index 7d155eee4d..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java +++ /dev/null @@ -1,370 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.connector.socket; - -import junit.framework.TestCase; -import org.junit.BeforeClass; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import java.nio.ByteBuffer; -import java.nio.channels.SocketChannel; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLEngineResult; -import javax.net.ssl.SSLSession; - -import org.neo4j.driver.internal.spi.Logger; -import org.neo4j.driver.internal.util.BytePrinter; - -import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; -import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; -import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW; -import static javax.net.ssl.SSLEngineResult.Status.OK; -import static junit.framework.TestCase.assertEquals; -import static junit.framework.TestCase.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -/** - * Tests related to the buffer uses in SSLSocketChannel - */ -public class TLSSocketChannelTest -{ - private ByteBuffer plainIn; - private ByteBuffer cipherIn; - private ByteBuffer cipherOut; - - private static int bufferSize; - private static SSLEngine sslEngine; - - @BeforeClass - public static void setup() - { - sslEngine = mock( SSLEngine.class ); - SSLSession session = mock( SSLSession.class ); - when( sslEngine.getSession() ).thenReturn( session ); - - // The strategy to enlarge the application buffer: double the size - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - bufferSize *= 2; - if ( bufferSize > 8 ) - { - fail( "We do not need a application buffer greater than 8 for all the SSL buffer tests" ); - } - return bufferSize; - } - } ).when( session ).getApplicationBufferSize(); - - // The strategy to enlarge the network buffer: double the size - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - bufferSize *= 2; - if ( bufferSize > 8 ) - { - fail( "We do not need a network buffer greater than 8 for all the SSL buffer tests" ); - } - return bufferSize; - } - } ).when( session ).getPacketBufferSize(); - } - - @Test - public void shouldEnlargeApplicationInputBuffer() throws Throwable - { - // Given - bufferSize = 2; - plainIn = ByteBuffer.allocate( bufferSize ); - ByteBuffer cipherIn = mock( ByteBuffer.class ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - // Write 00 01 02 03 04 05 06 into plainIn, simulating deciphering some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - plainIn = (ByteBuffer) args[1]; - ByteBuffer bytesDeciphered = createBufferWithContent( 6, 0, 6 ); - - // Simulating unwrap( cipherIn, plainIn ); - if ( plainIn.remaining() >= bytesDeciphered.capacity() ) - { - plainIn.put( bytesDeciphered ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_OVERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - ByteBuffer twoByteBuffer = ByteBuffer.allocate( 2 ); - sslChannel.read( twoByteBuffer ); - sslChannel.read( twoByteBuffer ); - sslChannel.read( twoByteBuffer ); - - // Then - // Should enlarge plainIn buffer to hold all deciphered bytes - assertEquals( 8, plainIn.capacity() ); - ByteBuffer.allocate( 2 ).flip(); - TestCase.assertEquals( "00 01 ", BytePrinter.hex( twoByteBuffer ) ); - - // When trying to read 4 existing bytes and then 6 more bytes - ByteBuffer buffer = ByteBuffer.allocate( 10 ); - while (buffer.hasRemaining()) - { - sslChannel.read( buffer ); - } - // Then - // Should drain previous deciphered bytes first and then append new bytes after - assertEquals( 8, plainIn.capacity() ); - buffer.flip(); - assertEquals( "02 03 04 05 00 01 02 03 04 05 ", BytePrinter.hex( buffer ) ); - } - - @Test - public void shouldEnlargeNetworkInputBuffer() throws Throwable - { - // Given - bufferSize = 2; - cipherIn = ByteBuffer.allocate( bufferSize ); - plainIn = ByteBuffer.allocate( 1024 ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - final ByteBuffer bytesFromChannel = createBufferWithContent( 6, 0, 6 ); - - // Simulating reading from channel and write into cipherIn - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - return TLSSocketChannel.bufferCopy( bytesFromChannel, cipherIn ); - } - } ).when( channel ).read( any( ByteBuffer.class ) ); - - // Write 00 01 02 03 04 05 06 into plainIn, simulating deciphering some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - plainIn = (ByteBuffer) args[1]; - - // Simulating unwrap( cipherIn, plainIn ); - if ( cipherIn.remaining() >= bytesFromChannel.capacity() ) - { - ByteBuffer bytesDeciphered = createBufferWithContent( 6, 0, 6 ); - plainIn.put( bytesDeciphered ); - cipherIn.position( cipherIn.limit() ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_UNDERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - - //When - sslChannel.read( ByteBuffer.allocate( 2 ) ); - sslChannel.read( ByteBuffer.allocate( 2 ) ); - sslChannel.read( ByteBuffer.allocate( 2 ) ); - - // Then - assertEquals( 8, cipherIn.capacity() ); - assertEquals( "00 01 02 03 04 05 00 00 ", BytePrinter.hex( cipherIn ) ); - } - - @Test - public void shouldCompactNetworkInputBufferBeforeReadingMoreFromChannel() throws Throwable - { - // Given - bufferSize = 8; - cipherIn = ByteBuffer.allocate( bufferSize ); - plainIn = ByteBuffer.allocate( 1024 ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - - // Simulate reading from channel and write into cipherIn - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - ByteBuffer bytesFromChannel = createBufferWithContent( 4, 0, 4 ); // write 00 01 02 03 into cipherIn - TLSSocketChannel.bufferCopy( bytesFromChannel, cipherIn ); - return cipherIn.position(); - } - } ).when( channel ).read( any( ByteBuffer.class ) ); - - - final int[] rounds = {0}; // A counter for how many times we've entered unwrap method - // Simulating deciphering some bytes. - // Unfortunately we cannot simply treat unwrap as a black box this time - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - rounds[0]++; - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - plainIn = (ByteBuffer) args[1]; - - switch ( rounds[0] ) - { - case 1: - // 00 01 02 03 -> [XX XX] 02 03 - cipherIn.position( 2 ); // consume the first 2 bytes and re-enter unwrap - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - case 2: - // [XX XX] 02 03 -> 02 03 - bufferSize = bufferSize / 2; // so that we could value the same size back - return new SSLEngineResult( BUFFER_UNDERFLOW, NOT_HANDSHAKING, 0, 0 ); // waiting for more data - case 3: - // 02 03 00 01 02 03 -> [XX XX XX XX XX XX] - ByteBuffer bytesDeciphered = createBufferWithContent( 2, 0, 2 ); - plainIn.put( bytesDeciphered ); - cipherIn.position( cipherIn.limit() ); // consume all bytes in cipherIn to exit unwrap - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - default: - fail( "Should not call unwrap after all the bytes in cipherIn already consumed and OK returned" ); - return null; - } - - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - //When - sslChannel.read( ByteBuffer.allocate( 8 ) ); - sslChannel.read( ByteBuffer.allocate( 8 ) ); - - // Then - assertEquals( 8, cipherIn.capacity() ); - assertEquals( "02 03 00 01 02 03 00 00 ", BytePrinter.hex( cipherIn ) ); - } - - @Test - public void shouldEnlargeNetworkOutputBuffer() throws Throwable - { - // Given - bufferSize = 2; - ByteBuffer cipherIn = mock( ByteBuffer.class ); - ByteBuffer plainIn = mock( ByteBuffer.class ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - cipherOut = ByteBuffer.allocate( bufferSize ); - - final ByteBuffer buffer = ByteBuffer.allocate( 2 ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - - // Simulating encrypting some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherOut = (ByteBuffer) args[1]; - - // Simulating wrap( buffer, cipherIn ); - ByteBuffer bytesToChannel = createBufferWithContent( 6, 0, 6 ); - if ( cipherOut.remaining() >= bytesToChannel.capacity() ) - { - buffer.position( buffer.limit() ); - cipherOut.put( bytesToChannel ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_OVERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - - } - } ).when( sslEngine ).wrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - //When - sslChannel.write( buffer ); - - // Then - assertEquals( 8, cipherOut.capacity() ); - assertEquals( "00 01 02 03 04 05 00 00 ", BytePrinter.hex( cipherOut ) ); - } - - private static ByteBuffer createBufferWithContent( int size, int contentStartPos, int contentLength ) - { - ByteBuffer buffer = ByteBuffer.allocate( size ); - - buffer.position( contentStartPos ); - buffer.limit( contentLength + contentStartPos ); - - for ( int i = 0; i < contentLength; i++ ) - { - buffer.put( i + contentStartPos, (byte) i ); - } - - return buffer; - } -} diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java index 807e8403e2..668fc7adcb 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java @@ -18,7 +18,6 @@ */ package org.neo4j.driver.v1.integration; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -36,7 +35,6 @@ import static org.junit.Assert.assertFalse; import static org.neo4j.driver.v1.Values.parameters; -@Ignore public class LoadCSVIT { @Rule diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java new file mode 100644 index 0000000000..5773ac3abd --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java @@ -0,0 +1,249 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.v1.integration; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.SocketChannel; +import java.security.GeneralSecurityException; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLServerSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +import org.neo4j.driver.internal.connector.socket.TLSSocketChannel; +import org.neo4j.driver.internal.logging.DevNullLogger; + +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThat; + +/** + * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we + * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 + * for the following variables: + * + * - Network frame size + * - Bolt message size + * - Read buffer size + * + * It tests every possible combination, and it does this currently only for the read path, expanding + * to the write path as well would be useful. For each size, it sets up a TLS server and tests the + * handshake, transferring the data, and verifying the data is correct after decryption. + */ +public class TLSSocketChannelFragmentationIT +{ + private SSLContext sslCtx; + private byte[] blobOfData; + private ServerSocket server; + + @Before + public void setup() throws Throwable + { + createSSLContext(); + createServer(); + } + + private void blobOfDataSize( int dataBlobSize ) + { + blobOfData = new byte[dataBlobSize]; + // If the blob is all zeros, we'd miss data corruption problems in assertions, so + // fill the data blob with different values. + for ( int i = 0; i < blobOfData.length; i++ ) + { + blobOfData[i] = (byte) (i % 128); + } + } + + @Test + public void shouldHandleFuzziness() throws Throwable + { + // Given + int networkFrameSize, userBufferSize, blobOfDataSize; + + for(int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude+=2 ) + { + blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude ); + + for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude+=2 ) + { + networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude ); + for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude+=2 ) + { + userBufferSize = (int) Math.pow( 2, userBufferMagnitude ); + testForBufferSizes( blobOfDataSize, networkFrameSize, userBufferSize ); + } + } + } + } + + private void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException + { + blobOfDataSize(blobOfDataSize); + SSLEngine engine = sslCtx.createSSLEngine(); + engine.setUseClientMode( true ); + ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) ); + ch = new LittleAtATimeChannel( ch, networkFrameSize ); + + TLSSocketChannel channel = new TLSSocketChannel(ch, new DevNullLogger(), engine); + try + { + ByteBuffer readBuffer = ByteBuffer.allocate( blobOfData.length ); + while ( readBuffer.position() < readBuffer.capacity() ) + { + readBuffer.limit(Math.min( readBuffer.capacity(), readBuffer.position() + userBufferSize )); + channel.read( readBuffer ); + } + + assertThat(readBuffer.array(), equalTo(blobOfData)); + } + finally + { + channel.close(); + } + } + + private void createSSLContext() + throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException, KeyManagementException + { + KeyStore ks = KeyStore.getInstance("JKS"); + char[] password = "password".toCharArray(); + ks.load( getClass().getResourceAsStream( "/keystore.jks" ), password ); + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, password); + + sslCtx = SSLContext.getInstance("TLS"); + sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() { + public void checkClientTrusted( X509Certificate[] chain, String authType) throws CertificateException + { + } + + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public X509Certificate[] getAcceptedIssuers() { + return null; + } + }}, null ); + } + + private void createServer() throws IOException + { + SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory(); + server = ssf.createServerSocket(0); + + new Thread(new Runnable() + { + @Override + public void run() + { + try + { + while(true) + { + Socket client = server.accept(); + OutputStream outputStream = client.getOutputStream(); + outputStream.write( blobOfData ); + outputStream.flush(); + // client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event + } + } + catch ( IOException e ) + { + e.printStackTrace(); + } + } + }).start(); + } + + /** + * Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate + * different network frame sizes in this test. + */ + private static class LittleAtATimeChannel implements ByteChannel + { + private final ByteChannel delegate; + private final int maxFrameSize; + + public LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize ) + { + + this.delegate = delegate; + this.maxFrameSize = maxFrameSize; + } + + @Override + public boolean isOpen() + { + return delegate.isOpen(); + } + + @Override + public void close() throws IOException + { + delegate.close(); + } + + @Override + public int write( ByteBuffer src ) throws IOException + { + int originalLimit = src.limit(); + try + { + src.limit( Math.min( src.limit(), src.position() + maxFrameSize ) ); + return delegate.write( src ); + } + finally + { + src.limit(originalLimit); + } + } + + @Override + public int read( ByteBuffer dst ) throws IOException + { + int originalLimit = dst.limit(); + try + { + dst.limit( Math.min( dst.limit(), dst.position() + maxFrameSize ) ); + return delegate.read( dst ); + } + finally + { + dst.limit(originalLimit); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java index e851e14008..60b2aa39ea 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java @@ -91,8 +91,6 @@ private void performTLSHandshakeUsingKnownCerts( File knownCerts ) throws Throwa sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } @@ -135,8 +133,6 @@ public void shouldPerformTLSHandshakeWithTrustedCert() throws Throwable sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } finally @@ -249,8 +245,6 @@ public void shouldPerformTLSHandshakeWithTheSameTrustedServerCert() throws Throw sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } diff --git a/driver/src/test/resources/keystore.jks b/driver/src/test/resources/keystore.jks new file mode 100644 index 0000000000..59beefc44a Binary files /dev/null and b/driver/src/test/resources/keystore.jks differ