Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ public static ByteChannel create( String host, int port, Config config, Logger l
SocketChannel soChannel = SocketChannel.open();
soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true );
soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true );

soChannel.connect( new InetSocketAddress( host, port ) );

ByteChannel channel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLSession;

import org.neo4j.driver.internal.spi.Logger;
import org.neo4j.driver.internal.util.BytePrinter;
Expand All @@ -51,10 +49,9 @@
*/
public class TLSSocketChannel implements ByteChannel
{
private final SocketChannel channel; // The real channel the data is sent to and read from
private final ByteChannel channel; // The real channel the data is sent to and read from
private final Logger logger;

private final SSLContext sslContext;
private SSLEngine sslEngine;

/** The buffer for network data */
Expand All @@ -66,34 +63,36 @@ public class TLSSocketChannel implements ByteChannel

private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocateDirect( 0 );

public TLSSocketChannel( String host, int port, SocketChannel channel, Logger logger,
public TLSSocketChannel( String host, int port, ByteChannel channel, Logger logger,
TrustStrategy trustStrategy )
throws GeneralSecurityException, IOException
{
logger.debug( "TLS connection enabled" );
this.logger = logger;
this.channel = channel;
this.channel.configureBlocking( true );
this(channel, logger,
createSSLEngine( host, port, new SSLContextFactory( host, port, trustStrategy, logger ).create() ) );

sslContext = new SSLContextFactory( host, port, trustStrategy, logger ).create();
createSSLEngine( host, port );
createBuffers();
runHandshake();
logger.debug( "TLS connection established" );
}

/** Used in internal tests only */
TLSSocketChannel( SocketChannel channel, Logger logger, SSLEngine sslEngine,
public TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws GeneralSecurityException, IOException
{
this(channel, logger, sslEngine,
ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ),
ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ),
ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ),
ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ) );
}

TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine,
ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut )
throws GeneralSecurityException, IOException
{
logger.debug( "Testing TLS buffers" );
this.logger = logger;
this.channel = channel;

this.sslContext = SSLContext.getInstance( "TLS" );
this.sslEngine = sslEngine;
resetBuffers( plainIn, cipherIn, plainOut, cipherOut ); // reset buffer size
this.plainIn = plainIn;
this.cipherIn = cipherIn;
this.plainOut = plainOut;
this.cipherOut = cipherOut;
runHandshake();
}

/**
Expand Down Expand Up @@ -126,17 +125,13 @@ private void runHandshake() throws IOException
case NEED_UNWRAP:
// Unwrap the ssl packet to value ssl handshake information
handshakeStatus = unwrap( DUMMY_BUFFER );
plainIn.clear();
break;
case NEED_WRAP:
// Wrap the app packet into an ssl packet to add ssl handshake information
handshakeStatus = wrap( plainOut );
break;
}
}

plainIn.clear();
plainOut.clear();
}

private HandshakeStatus runDelegatedTasks()
Expand Down Expand Up @@ -185,10 +180,11 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
}
cipherIn.flip();

Status status = null;
Status status;
do
{
status = sslEngine.unwrap( cipherIn, plainIn ).getStatus();
SSLEngineResult unwrapResult = sslEngine.unwrap( cipherIn, plainIn );
status = unwrapResult.getStatus();
// Possible status here:
// OK - good
// BUFFER_OVERFLOW - we need to enlarge* plainIn
Expand Down Expand Up @@ -244,17 +240,13 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
// Otherwise, make room for reading more data from channel
cipherIn.compact();
}

// I skipped the following check as it "should not" happen at all:
// The channel should not provide us ciphered bytes that cannot hold in the channel buffer at all
// if( cipherIn.remaining() == 0 )
// {throw new ClientException( "cannot enlarge as it already reached the limit" );}

// Obtain more inbound network data for cipherIn,
// then retry the operation.
return handshakeStatus; // old status
case CLOSED:
// RFC 2246 #7.2.1 requires us to stop accepting input.
sslEngine.closeInbound();
break;
default:
throw new ClientException( "Got unexpected status " + status );
throw new ClientException( "Got unexpected status " + status + ", " + unwrapResult );
}
}
while ( cipherIn.hasRemaining() ); /* Remember we are doing blocking reading.
Expand Down Expand Up @@ -285,7 +277,10 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
case OK:
handshakeStatus = runDelegatedTasks();
cipherOut.flip();
channel.write( cipherOut );
while(cipherOut.hasRemaining())
{
channel.write( cipherOut );
}
cipherOut.clear();
break;
case BUFFER_OVERFLOW:
Expand Down Expand Up @@ -344,42 +339,17 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to )
return maxTransfer;
}

/**
* Create network buffers and application buffers
*
* @throws IOException
*/
private void createBuffers() throws IOException
{
SSLSession session = sslEngine.getSession();
int appBufferSize = session.getApplicationBufferSize();
int netBufferSize = session.getPacketBufferSize();

plainOut = ByteBuffer.allocateDirect( appBufferSize );
plainIn = ByteBuffer.allocateDirect( appBufferSize );
cipherOut = ByteBuffer.allocateDirect( netBufferSize );
cipherIn = ByteBuffer.allocateDirect( netBufferSize );
}

/** Should only be used in tests */
void resetBuffers( ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut )
{
this.plainIn = plainIn;
this.cipherIn = cipherIn;
this.plainOut = plainOut;
this.cipherOut = cipherOut;
}

/**
* Create SSLEngine with the SSLContext just created.
*
* @param host
* @param port
* @param sslContext
*/
private void createSSLEngine( String host, int port )
private static SSLEngine createSSLEngine( String host, int port, SSLContext sslContext )
{
sslEngine = sslContext.createSSLEngine( host, port );
SSLEngine sslEngine = sslContext.createSSLEngine( host, port );
sslEngine.setUseClientMode( true );
return sslEngine;
}

@Override
Expand Down Expand Up @@ -431,33 +401,46 @@ public boolean isOpen()
@Override
public void close() throws IOException
{
plainOut.clear();
// Indicate that application is done with engine
sslEngine.closeOutbound();

while ( !sslEngine.isOutboundDone() )
try
{
// Get close message
SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut );

// Check res statuses
plainOut.clear();
// Indicate that application is done with engine
sslEngine.closeOutbound();

// Send close message to peer
cipherOut.flip();
while ( cipherOut.hasRemaining() )
while ( !sslEngine.isOutboundDone() )
{
int num = channel.write( cipherOut );
if ( num == -1 )
// Get close message
SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut );

// Check res statuses

// Send close message to peer
cipherOut.flip();
while ( cipherOut.hasRemaining() )
{
// handle closed channel
break;
int num = channel.write( cipherOut );
if ( num == -1 )
{
// handle closed channel
break;
}
}
cipherOut.clear();
}
cipherOut.clear();
// Close transport
channel.close();
logger.debug( "TLS connection closed" );
}
catch(IOException e)
{
// Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly.
logger.warn( "TLS socket could not be closed cleanly: '"+e.getMessage()+"'", e );
}
// Close transport
channel.close();
logger.debug( "TLS connection closed" );
}

@Override
public String toString()
{
return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}";
}
}
Loading