From 54610a418f7ccedd0194a69e7dee1107d2651088 Mon Sep 17 00:00:00 2001 From: Jacob Hansson Date: Wed, 20 Apr 2016 16:53:27 +0200 Subject: [PATCH] Resolve deadlock issue in TLSSocketChannel This introduces a comprehensive test suite for testing thousands of different buffer size, network frame size and message size combinations for the TLSSocketChannel. This test uncovered several issues, notably it found a blocking issue that is resolved in this commit as well, where the hand shake portion of the exchange would also read a bit "into" user data space, and that data would then get lost - subsequently, when the client layer asked for the amount of data it "expected" to be in place, it'd block indefinitely because of that lost data ensuring we'd never add up to the expected amount. Notably, this commit also removes four unit tests. These tests are covered by the more comprehensive test, and were rather mock and stub heavy, meaning it seemed preferrable to remove them than to refactor them to work with the changes to the implementation --- .../connector/socket/SocketClient.java | 1 + .../connector/socket/TLSSocketChannel.java | 153 ++++---- .../socket/TLSSocketChannelTest.java | 370 ------------------ .../driver/v1/integration/LoadCSVIT.java | 2 - .../TLSSocketChannelFragmentationIT.java | 249 ++++++++++++ .../v1/integration/TLSSocketChannelIT.java | 6 - driver/src/test/resources/keystore.jks | Bin 0 -> 2360 bytes 7 files changed, 318 insertions(+), 463 deletions(-) delete mode 100644 driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java create mode 100644 driver/src/test/resources/keystore.jks diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java index 41aee7d6d9..d643f50000 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java @@ -215,6 +215,7 @@ public static ByteChannel create( String host, int port, Config config, Logger l SocketChannel soChannel = SocketChannel.open(); soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true ); soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true ); + soChannel.connect( new InetSocketAddress( host, port ) ); ByteChannel channel; diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java index 09d1c773df..a370ef4765 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java @@ -21,14 +21,12 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; -import java.nio.channels.SocketChannel; import java.security.GeneralSecurityException; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult.HandshakeStatus; import javax.net.ssl.SSLEngineResult.Status; -import javax.net.ssl.SSLSession; import org.neo4j.driver.internal.spi.Logger; import org.neo4j.driver.internal.util.BytePrinter; @@ -51,10 +49,9 @@ */ public class TLSSocketChannel implements ByteChannel { - private final SocketChannel channel; // The real channel the data is sent to and read from + private final ByteChannel channel; // The real channel the data is sent to and read from private final Logger logger; - private final SSLContext sslContext; private SSLEngine sslEngine; /** The buffer for network data */ @@ -66,34 +63,36 @@ public class TLSSocketChannel implements ByteChannel private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocateDirect( 0 ); - public TLSSocketChannel( String host, int port, SocketChannel channel, Logger logger, + public TLSSocketChannel( String host, int port, ByteChannel channel, Logger logger, TrustStrategy trustStrategy ) throws GeneralSecurityException, IOException { - logger.debug( "TLS connection enabled" ); - this.logger = logger; - this.channel = channel; - this.channel.configureBlocking( true ); + this(channel, logger, + createSSLEngine( host, port, new SSLContextFactory( host, port, trustStrategy, logger ).create() ) ); - sslContext = new SSLContextFactory( host, port, trustStrategy, logger ).create(); - createSSLEngine( host, port ); - createBuffers(); - runHandshake(); - logger.debug( "TLS connection established" ); } - /** Used in internal tests only */ - TLSSocketChannel( SocketChannel channel, Logger logger, SSLEngine sslEngine, + public TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws GeneralSecurityException, IOException + { + this(channel, logger, sslEngine, + ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ), + ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ) ); + } + + TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine, ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut ) throws GeneralSecurityException, IOException { - logger.debug( "Testing TLS buffers" ); this.logger = logger; this.channel = channel; - - this.sslContext = SSLContext.getInstance( "TLS" ); this.sslEngine = sslEngine; - resetBuffers( plainIn, cipherIn, plainOut, cipherOut ); // reset buffer size + this.plainIn = plainIn; + this.cipherIn = cipherIn; + this.plainOut = plainOut; + this.cipherOut = cipherOut; + runHandshake(); } /** @@ -126,7 +125,6 @@ private void runHandshake() throws IOException case NEED_UNWRAP: // Unwrap the ssl packet to value ssl handshake information handshakeStatus = unwrap( DUMMY_BUFFER ); - plainIn.clear(); break; case NEED_WRAP: // Wrap the app packet into an ssl packet to add ssl handshake information @@ -134,9 +132,6 @@ private void runHandshake() throws IOException break; } } - - plainIn.clear(); - plainOut.clear(); } private HandshakeStatus runDelegatedTasks() @@ -185,10 +180,11 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException } cipherIn.flip(); - Status status = null; + Status status; do { - status = sslEngine.unwrap( cipherIn, plainIn ).getStatus(); + SSLEngineResult unwrapResult = sslEngine.unwrap( cipherIn, plainIn ); + status = unwrapResult.getStatus(); // Possible status here: // OK - good // BUFFER_OVERFLOW - we need to enlarge* plainIn @@ -244,17 +240,13 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException // Otherwise, make room for reading more data from channel cipherIn.compact(); } - - // I skipped the following check as it "should not" happen at all: - // The channel should not provide us ciphered bytes that cannot hold in the channel buffer at all - // if( cipherIn.remaining() == 0 ) - // {throw new ClientException( "cannot enlarge as it already reached the limit" );} - - // Obtain more inbound network data for cipherIn, - // then retry the operation. return handshakeStatus; // old status + case CLOSED: + // RFC 2246 #7.2.1 requires us to stop accepting input. + sslEngine.closeInbound(); + break; default: - throw new ClientException( "Got unexpected status " + status ); + throw new ClientException( "Got unexpected status " + status + ", " + unwrapResult ); } } while ( cipherIn.hasRemaining() ); /* Remember we are doing blocking reading. @@ -285,7 +277,10 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException case OK: handshakeStatus = runDelegatedTasks(); cipherOut.flip(); - channel.write( cipherOut ); + while(cipherOut.hasRemaining()) + { + channel.write( cipherOut ); + } cipherOut.clear(); break; case BUFFER_OVERFLOW: @@ -344,42 +339,17 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to ) return maxTransfer; } - /** - * Create network buffers and application buffers - * - * @throws IOException - */ - private void createBuffers() throws IOException - { - SSLSession session = sslEngine.getSession(); - int appBufferSize = session.getApplicationBufferSize(); - int netBufferSize = session.getPacketBufferSize(); - - plainOut = ByteBuffer.allocateDirect( appBufferSize ); - plainIn = ByteBuffer.allocateDirect( appBufferSize ); - cipherOut = ByteBuffer.allocateDirect( netBufferSize ); - cipherIn = ByteBuffer.allocateDirect( netBufferSize ); - } - - /** Should only be used in tests */ - void resetBuffers( ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut ) - { - this.plainIn = plainIn; - this.cipherIn = cipherIn; - this.plainOut = plainOut; - this.cipherOut = cipherOut; - } - /** * Create SSLEngine with the SSLContext just created. - * * @param host * @param port + * @param sslContext */ - private void createSSLEngine( String host, int port ) + private static SSLEngine createSSLEngine( String host, int port, SSLContext sslContext ) { - sslEngine = sslContext.createSSLEngine( host, port ); + SSLEngine sslEngine = sslContext.createSSLEngine( host, port ); sslEngine.setUseClientMode( true ); + return sslEngine; } @Override @@ -431,33 +401,46 @@ public boolean isOpen() @Override public void close() throws IOException { - plainOut.clear(); - // Indicate that application is done with engine - sslEngine.closeOutbound(); - - while ( !sslEngine.isOutboundDone() ) + try { - // Get close message - SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut ); - - // Check res statuses + plainOut.clear(); + // Indicate that application is done with engine + sslEngine.closeOutbound(); - // Send close message to peer - cipherOut.flip(); - while ( cipherOut.hasRemaining() ) + while ( !sslEngine.isOutboundDone() ) { - int num = channel.write( cipherOut ); - if ( num == -1 ) + // Get close message + SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut ); + + // Check res statuses + + // Send close message to peer + cipherOut.flip(); + while ( cipherOut.hasRemaining() ) { - // handle closed channel - break; + int num = channel.write( cipherOut ); + if ( num == -1 ) + { + // handle closed channel + break; + } } + cipherOut.clear(); } - cipherOut.clear(); + // Close transport + channel.close(); + logger.debug( "TLS connection closed" ); + } + catch(IOException e) + { + // Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly. + logger.warn( "TLS socket could not be closed cleanly: '"+e.getMessage()+"'", e ); } - // Close transport - channel.close(); - logger.debug( "TLS connection closed" ); } + @Override + public String toString() + { + return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}"; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java b/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java deleted file mode 100644 index 7d155eee4d..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannelTest.java +++ /dev/null @@ -1,370 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.connector.socket; - -import junit.framework.TestCase; -import org.junit.BeforeClass; -import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -import java.nio.ByteBuffer; -import java.nio.channels.SocketChannel; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLEngineResult; -import javax.net.ssl.SSLSession; - -import org.neo4j.driver.internal.spi.Logger; -import org.neo4j.driver.internal.util.BytePrinter; - -import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; -import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; -import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW; -import static javax.net.ssl.SSLEngineResult.Status.OK; -import static junit.framework.TestCase.assertEquals; -import static junit.framework.TestCase.fail; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -/** - * Tests related to the buffer uses in SSLSocketChannel - */ -public class TLSSocketChannelTest -{ - private ByteBuffer plainIn; - private ByteBuffer cipherIn; - private ByteBuffer cipherOut; - - private static int bufferSize; - private static SSLEngine sslEngine; - - @BeforeClass - public static void setup() - { - sslEngine = mock( SSLEngine.class ); - SSLSession session = mock( SSLSession.class ); - when( sslEngine.getSession() ).thenReturn( session ); - - // The strategy to enlarge the application buffer: double the size - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - bufferSize *= 2; - if ( bufferSize > 8 ) - { - fail( "We do not need a application buffer greater than 8 for all the SSL buffer tests" ); - } - return bufferSize; - } - } ).when( session ).getApplicationBufferSize(); - - // The strategy to enlarge the network buffer: double the size - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - bufferSize *= 2; - if ( bufferSize > 8 ) - { - fail( "We do not need a network buffer greater than 8 for all the SSL buffer tests" ); - } - return bufferSize; - } - } ).when( session ).getPacketBufferSize(); - } - - @Test - public void shouldEnlargeApplicationInputBuffer() throws Throwable - { - // Given - bufferSize = 2; - plainIn = ByteBuffer.allocate( bufferSize ); - ByteBuffer cipherIn = mock( ByteBuffer.class ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - // Write 00 01 02 03 04 05 06 into plainIn, simulating deciphering some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - plainIn = (ByteBuffer) args[1]; - ByteBuffer bytesDeciphered = createBufferWithContent( 6, 0, 6 ); - - // Simulating unwrap( cipherIn, plainIn ); - if ( plainIn.remaining() >= bytesDeciphered.capacity() ) - { - plainIn.put( bytesDeciphered ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_OVERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - ByteBuffer twoByteBuffer = ByteBuffer.allocate( 2 ); - sslChannel.read( twoByteBuffer ); - sslChannel.read( twoByteBuffer ); - sslChannel.read( twoByteBuffer ); - - // Then - // Should enlarge plainIn buffer to hold all deciphered bytes - assertEquals( 8, plainIn.capacity() ); - ByteBuffer.allocate( 2 ).flip(); - TestCase.assertEquals( "00 01 ", BytePrinter.hex( twoByteBuffer ) ); - - // When trying to read 4 existing bytes and then 6 more bytes - ByteBuffer buffer = ByteBuffer.allocate( 10 ); - while (buffer.hasRemaining()) - { - sslChannel.read( buffer ); - } - // Then - // Should drain previous deciphered bytes first and then append new bytes after - assertEquals( 8, plainIn.capacity() ); - buffer.flip(); - assertEquals( "02 03 04 05 00 01 02 03 04 05 ", BytePrinter.hex( buffer ) ); - } - - @Test - public void shouldEnlargeNetworkInputBuffer() throws Throwable - { - // Given - bufferSize = 2; - cipherIn = ByteBuffer.allocate( bufferSize ); - plainIn = ByteBuffer.allocate( 1024 ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - final ByteBuffer bytesFromChannel = createBufferWithContent( 6, 0, 6 ); - - // Simulating reading from channel and write into cipherIn - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - return TLSSocketChannel.bufferCopy( bytesFromChannel, cipherIn ); - } - } ).when( channel ).read( any( ByteBuffer.class ) ); - - // Write 00 01 02 03 04 05 06 into plainIn, simulating deciphering some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - plainIn = (ByteBuffer) args[1]; - - // Simulating unwrap( cipherIn, plainIn ); - if ( cipherIn.remaining() >= bytesFromChannel.capacity() ) - { - ByteBuffer bytesDeciphered = createBufferWithContent( 6, 0, 6 ); - plainIn.put( bytesDeciphered ); - cipherIn.position( cipherIn.limit() ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_UNDERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - - //When - sslChannel.read( ByteBuffer.allocate( 2 ) ); - sslChannel.read( ByteBuffer.allocate( 2 ) ); - sslChannel.read( ByteBuffer.allocate( 2 ) ); - - // Then - assertEquals( 8, cipherIn.capacity() ); - assertEquals( "00 01 02 03 04 05 00 00 ", BytePrinter.hex( cipherIn ) ); - } - - @Test - public void shouldCompactNetworkInputBufferBeforeReadingMoreFromChannel() throws Throwable - { - // Given - bufferSize = 8; - cipherIn = ByteBuffer.allocate( bufferSize ); - plainIn = ByteBuffer.allocate( 1024 ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - ByteBuffer cipherOut = mock( ByteBuffer.class ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - - // Simulate reading from channel and write into cipherIn - doAnswer( new Answer() - { - @Override - public Integer answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - ByteBuffer bytesFromChannel = createBufferWithContent( 4, 0, 4 ); // write 00 01 02 03 into cipherIn - TLSSocketChannel.bufferCopy( bytesFromChannel, cipherIn ); - return cipherIn.position(); - } - } ).when( channel ).read( any( ByteBuffer.class ) ); - - - final int[] rounds = {0}; // A counter for how many times we've entered unwrap method - // Simulating deciphering some bytes. - // Unfortunately we cannot simply treat unwrap as a black box this time - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - rounds[0]++; - Object[] args = invocation.getArguments(); - cipherIn = (ByteBuffer) args[0]; - plainIn = (ByteBuffer) args[1]; - - switch ( rounds[0] ) - { - case 1: - // 00 01 02 03 -> [XX XX] 02 03 - cipherIn.position( 2 ); // consume the first 2 bytes and re-enter unwrap - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - case 2: - // [XX XX] 02 03 -> 02 03 - bufferSize = bufferSize / 2; // so that we could value the same size back - return new SSLEngineResult( BUFFER_UNDERFLOW, NOT_HANDSHAKING, 0, 0 ); // waiting for more data - case 3: - // 02 03 00 01 02 03 -> [XX XX XX XX XX XX] - ByteBuffer bytesDeciphered = createBufferWithContent( 2, 0, 2 ); - plainIn.put( bytesDeciphered ); - cipherIn.position( cipherIn.limit() ); // consume all bytes in cipherIn to exit unwrap - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - default: - fail( "Should not call unwrap after all the bytes in cipherIn already consumed and OK returned" ); - return null; - } - - } - } ).when( sslEngine ).unwrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - //When - sslChannel.read( ByteBuffer.allocate( 8 ) ); - sslChannel.read( ByteBuffer.allocate( 8 ) ); - - // Then - assertEquals( 8, cipherIn.capacity() ); - assertEquals( "02 03 00 01 02 03 00 00 ", BytePrinter.hex( cipherIn ) ); - } - - @Test - public void shouldEnlargeNetworkOutputBuffer() throws Throwable - { - // Given - bufferSize = 2; - ByteBuffer cipherIn = mock( ByteBuffer.class ); - ByteBuffer plainIn = mock( ByteBuffer.class ); - ByteBuffer plainOut = mock( ByteBuffer.class ); - cipherOut = ByteBuffer.allocate( bufferSize ); - - final ByteBuffer buffer = ByteBuffer.allocate( 2 ); - - SocketChannel channel = mock( SocketChannel.class ); - Logger logger = mock( Logger.class ); - - TLSSocketChannel sslChannel = - new TLSSocketChannel( channel, logger, sslEngine, plainIn, cipherIn, plainOut, cipherOut ); - - - // Simulating encrypting some bytes - doAnswer( new Answer() - { - @Override - public SSLEngineResult answer( InvocationOnMock invocation ) throws Throwable - { - Object[] args = invocation.getArguments(); - cipherOut = (ByteBuffer) args[1]; - - // Simulating wrap( buffer, cipherIn ); - ByteBuffer bytesToChannel = createBufferWithContent( 6, 0, 6 ); - if ( cipherOut.remaining() >= bytesToChannel.capacity() ) - { - buffer.position( buffer.limit() ); - cipherOut.put( bytesToChannel ); - return new SSLEngineResult( OK, NOT_HANDSHAKING, 0, 0 ); - } - else - { - return new SSLEngineResult( BUFFER_OVERFLOW, NOT_HANDSHAKING, 0, 0 ); - } - - } - } ).when( sslEngine ).wrap( any( ByteBuffer.class ), any( ByteBuffer.class ) ); - - //When - sslChannel.write( buffer ); - - // Then - assertEquals( 8, cipherOut.capacity() ); - assertEquals( "00 01 02 03 04 05 00 00 ", BytePrinter.hex( cipherOut ) ); - } - - private static ByteBuffer createBufferWithContent( int size, int contentStartPos, int contentLength ) - { - ByteBuffer buffer = ByteBuffer.allocate( size ); - - buffer.position( contentStartPos ); - buffer.limit( contentLength + contentStartPos ); - - for ( int i = 0; i < contentLength; i++ ) - { - buffer.put( i + contentStartPos, (byte) i ); - } - - return buffer; - } -} diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java index 807e8403e2..668fc7adcb 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/LoadCSVIT.java @@ -18,7 +18,6 @@ */ package org.neo4j.driver.v1.integration; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -36,7 +35,6 @@ import static org.junit.Assert.assertFalse; import static org.neo4j.driver.v1.Values.parameters; -@Ignore public class LoadCSVIT { @Rule diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java new file mode 100644 index 0000000000..5773ac3abd --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentationIT.java @@ -0,0 +1,249 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.v1.integration; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.SocketChannel; +import java.security.GeneralSecurityException; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLServerSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; + +import org.neo4j.driver.internal.connector.socket.TLSSocketChannel; +import org.neo4j.driver.internal.logging.DevNullLogger; + +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThat; + +/** + * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we + * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16 + * for the following variables: + * + * - Network frame size + * - Bolt message size + * - Read buffer size + * + * It tests every possible combination, and it does this currently only for the read path, expanding + * to the write path as well would be useful. For each size, it sets up a TLS server and tests the + * handshake, transferring the data, and verifying the data is correct after decryption. + */ +public class TLSSocketChannelFragmentationIT +{ + private SSLContext sslCtx; + private byte[] blobOfData; + private ServerSocket server; + + @Before + public void setup() throws Throwable + { + createSSLContext(); + createServer(); + } + + private void blobOfDataSize( int dataBlobSize ) + { + blobOfData = new byte[dataBlobSize]; + // If the blob is all zeros, we'd miss data corruption problems in assertions, so + // fill the data blob with different values. + for ( int i = 0; i < blobOfData.length; i++ ) + { + blobOfData[i] = (byte) (i % 128); + } + } + + @Test + public void shouldHandleFuzziness() throws Throwable + { + // Given + int networkFrameSize, userBufferSize, blobOfDataSize; + + for(int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude+=2 ) + { + blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude ); + + for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude+=2 ) + { + networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude ); + for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude+=2 ) + { + userBufferSize = (int) Math.pow( 2, userBufferMagnitude ); + testForBufferSizes( blobOfDataSize, networkFrameSize, userBufferSize ); + } + } + } + } + + private void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException + { + blobOfDataSize(blobOfDataSize); + SSLEngine engine = sslCtx.createSSLEngine(); + engine.setUseClientMode( true ); + ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) ); + ch = new LittleAtATimeChannel( ch, networkFrameSize ); + + TLSSocketChannel channel = new TLSSocketChannel(ch, new DevNullLogger(), engine); + try + { + ByteBuffer readBuffer = ByteBuffer.allocate( blobOfData.length ); + while ( readBuffer.position() < readBuffer.capacity() ) + { + readBuffer.limit(Math.min( readBuffer.capacity(), readBuffer.position() + userBufferSize )); + channel.read( readBuffer ); + } + + assertThat(readBuffer.array(), equalTo(blobOfData)); + } + finally + { + channel.close(); + } + } + + private void createSSLContext() + throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException, UnrecoverableKeyException, KeyManagementException + { + KeyStore ks = KeyStore.getInstance("JKS"); + char[] password = "password".toCharArray(); + ks.load( getClass().getResourceAsStream( "/keystore.jks" ), password ); + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, password); + + sslCtx = SSLContext.getInstance("TLS"); + sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() { + public void checkClientTrusted( X509Certificate[] chain, String authType) throws CertificateException + { + } + + public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { + } + + public X509Certificate[] getAcceptedIssuers() { + return null; + } + }}, null ); + } + + private void createServer() throws IOException + { + SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory(); + server = ssf.createServerSocket(0); + + new Thread(new Runnable() + { + @Override + public void run() + { + try + { + while(true) + { + Socket client = server.accept(); + OutputStream outputStream = client.getOutputStream(); + outputStream.write( blobOfData ); + outputStream.flush(); + // client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event + } + } + catch ( IOException e ) + { + e.printStackTrace(); + } + } + }).start(); + } + + /** + * Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate + * different network frame sizes in this test. + */ + private static class LittleAtATimeChannel implements ByteChannel + { + private final ByteChannel delegate; + private final int maxFrameSize; + + public LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize ) + { + + this.delegate = delegate; + this.maxFrameSize = maxFrameSize; + } + + @Override + public boolean isOpen() + { + return delegate.isOpen(); + } + + @Override + public void close() throws IOException + { + delegate.close(); + } + + @Override + public int write( ByteBuffer src ) throws IOException + { + int originalLimit = src.limit(); + try + { + src.limit( Math.min( src.limit(), src.position() + maxFrameSize ) ); + return delegate.write( src ); + } + finally + { + src.limit(originalLimit); + } + } + + @Override + public int read( ByteBuffer dst ) throws IOException + { + int originalLimit = dst.limit(); + try + { + dst.limit( Math.min( dst.limit(), dst.position() + maxFrameSize ) ); + return delegate.read( dst ); + } + finally + { + dst.limit(originalLimit); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java index e851e14008..60b2aa39ea 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java @@ -91,8 +91,6 @@ private void performTLSHandshakeUsingKnownCerts( File knownCerts ) throws Throwa sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } @@ -135,8 +133,6 @@ public void shouldPerformTLSHandshakeWithTrustedCert() throws Throwable sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } finally @@ -249,8 +245,6 @@ public void shouldPerformTLSHandshakeWithTheSameTrustedServerCert() throws Throw sslChannel.close(); // Then - verify( logger, atLeastOnce() ).debug( "TLS connection enabled" ); - verify( logger, atLeastOnce() ).debug( "TLS connection established" ); verify( logger, atLeastOnce() ).debug( "TLS connection closed" ); } diff --git a/driver/src/test/resources/keystore.jks b/driver/src/test/resources/keystore.jks new file mode 100644 index 0000000000000000000000000000000000000000..59beefc44a1b72a53276ab7b6faddcf623775d28 GIT binary patch literal 2360 zcmdT_`9Bkm8{cN+9-EsnXCpLs2|1FvmSb|$nq#g_X0~dIxyvo0l9F>uzLl#IDTz5N zl3S9a=BN-YKKgu**XLjO{_yK0HQn9nnl*|jD(uZ$eb zM#BB2e&fykUm5RajBpel@G8otBdPW@IVYLAZaQ0 z{Myg4@T&H9jD*rNrAo8rvi$s5I%of4a89yHP*KPik^1(Et=w3p(+mtAMK9PCge`gl@nMr~DTRv3de zYd;&}fQ7JzzU+BuA=8fosgm#25lK}o`BQZq1Lp?)_pwW8;XcDJ}o^LiH^l2brB4$ zo%R$bC+(HpvtO@R2 z@0zoAk9$>kEHSq`{dM#~x5z`8ns{=_nFPYSkBM4mHlm^=8F`Y72b^HOev2(6dln=1 z>Uk`7ayRBa^CAa=#V)>(5;=0Oe=YSP`f(@iLmLS{VQ*s+tmN5}KBkS-?#pQ1cE1V; zqX&O%5y+AaKT>Bd{%E#?5)Hf3`*B8?p^+-C;=@Oc{~>Y((#Q(G74qW*<>BB;EQk0R zJMqj8?f7ZUG@GeDGql-BE0ePw*E?V2O#ASb$dVNtoNu>TsBCBXHtqv|dS{~?8dlbWs z^C`E&!A(k33X`1d-*xa@YrhE{D^`+N3Cnj4?LG+kK?6#4^U^pZMS~ytH588;_4i~_wqZ>=HNp19N+v(*(kS@4DWO6 zC##o5za-nWep8Wa1q#RepSzV(pKy_fxgD<>#)0}71#B%KUjGu8zeapJeVz*tJcmp= zX&f)~l-ZqyL28pUvNH74i_owz{7~9cdP4MgRKQ2dhRq!=50@IF#+kehytgZFSV7(Tf`(byzE!vGkT@ni2k_O#;icsHjRe^dNkn{N>z>d1?0q+m+gBTt#8NA{6Nw^t&*+G&w|xnW0I7=|0B1DwgKue z0Km0~VsOo&7@*!;Tp%C_1lH)I7NLZA_%xEy$(uqzpupb4j4#y@K=JW#xq*3LAU9VX zQ5574;DzxzhLFi5N>qrBEK2%MYaa}1MGB`xoDRWhBK}9JqLlxT{4hB?Tv#9(VH1Um zBp{CAqH(8jc|LNB3@fZ0wwlG#RU_x z!{JG%5ojE7Zw4YtUQ$TM5T&nkSVzy`kS@wgQb^YjrK@ZBANc>V1Ouq_PxJTEiHiY* z08k7dKZpSY0xBSR`bn#EC&*?d)mQCZJf|AKL%H9-qrakKh}X+Bci&8GKB;eXUw3?+ z1V8$WTMtR~?9Q=$LGYgSQcQSJ{RRj|+dNZ#1USTh=nXQX4f|PCWiQNt)$%a~U_UVlKO+<|&y9-A8ZP7G5 zA1BP=rMC7qi8}>p#xC?i5kbG#_56%E`h4^z^@kERc>!5FqN9SzvWd>nO0PI6HVCH( zf0J@i9U6Nt#&M`*H}n+5R1GcTN$kzvARiFGBd@2DY3BEOc-y|t+*nx}P=nVb`>#qT UYTaIO%69#iI_ge-s$n1GFM7uT3IG5A literal 0 HcmV?d00001