Skip to content

Commit

Permalink
Refactor Netty hostname verification to clarify that it is client side
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkerr9000 committed Sep 14, 2018
1 parent a080a20 commit 6428e05
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
Expand Up @@ -23,7 +23,10 @@
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLParameters;


public class HostnameVerificationEngineModification implements Function<SSLEngine,SSLEngine> /**
* Client side modifier for SSLEngine to mandate hostname verification
*/
public class ClientSideHostnameVerificationEngineModification implements Function<SSLEngine,SSLEngine>
{ {
/** /**
* Apply modifications to engine to enable hostname verification (client side only) * Apply modifications to engine to enable hostname verification (client side only)
Expand Down
Expand Up @@ -37,22 +37,22 @@
import java.util.function.Function; import java.util.function.Function;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;


public class OnConnectSslHandler extends ChannelDuplexHandler public class ClientSideOnConnectSslHandler extends ChannelDuplexHandler
{ {
private final ChannelPipeline pipeline; private final ChannelPipeline pipeline;
private final SslContext sslContext; private final SslContext sslContext;
private final Collection<Function<SSLEngine,SSLEngine>> engineModifications; private final Collection<Function<SSLEngine,SSLEngine>> engineModifications;


OnConnectSslHandler( Channel channel, SslContext sslContext, boolean isClient, boolean verifyHostname, String[] tlsVersions ) ClientSideOnConnectSslHandler( Channel channel, SslContext sslContext, boolean verifyHostname, String[] tlsVersions )
{ {
this.pipeline = channel.pipeline(); this.pipeline = channel.pipeline();
this.sslContext = sslContext; this.sslContext = sslContext;


this.engineModifications = new ArrayList<>(); this.engineModifications = new ArrayList<>();
engineModifications.add( new EssentialEngineModifications( tlsVersions, isClient ) ); engineModifications.add( new EssentialEngineModifications( tlsVersions, true ) );
if ( verifyHostname ) if ( verifyHostname )
{ {
engineModifications.add( new HostnameVerificationEngineModification() ); engineModifications.add( new ClientSideHostnameVerificationEngineModification() );
} }
} }


Expand Down
9 changes: 3 additions & 6 deletions community/ssl/src/main/java/org/neo4j/ssl/SslPolicy.java
Expand Up @@ -35,7 +35,6 @@
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.TrustManagerFactory;
import javax.xml.bind.DatatypeConverter;


import org.neo4j.logging.Log; import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider; import org.neo4j.logging.LogProvider;
Expand All @@ -55,7 +54,6 @@ public class SslPolicy
private final SslProvider sslProvider; private final SslProvider sslProvider;


private final boolean verifyHostname; private final boolean verifyHostname;
private final LogProvider logProvider;
private final Log log; private final Log log;


public SslPolicy( PrivateKey privateKey, X509Certificate[] keyCertChain, List<String> tlsVersions, List<String> ciphers, ClientAuth clientAuth, public SslPolicy( PrivateKey privateKey, X509Certificate[] keyCertChain, List<String> tlsVersions, List<String> ciphers, ClientAuth clientAuth,
Expand All @@ -69,7 +67,6 @@ public SslPolicy( PrivateKey privateKey, X509Certificate[] keyCertChain, List<St
this.trustManagerFactory = trustManagerFactory; this.trustManagerFactory = trustManagerFactory;
this.sslProvider = sslProvider; this.sslProvider = sslProvider;
this.verifyHostname = verifyHostname; this.verifyHostname = verifyHostname;
this.logProvider = logProvider;
this.log = logProvider.getLog( SslPolicy.class ); this.log = logProvider.getLog( SslPolicy.class );
} }


Expand Down Expand Up @@ -116,7 +113,7 @@ public ChannelHandler nettyServerHandler( Channel channel ) throws SSLException
return nettyServerHandler( channel, nettyServerContext() ); return nettyServerHandler( channel, nettyServerContext() );
} }


private ChannelHandler nettyServerHandler( Channel channel, SslContext sslContext ) throws SSLException private ChannelHandler nettyServerHandler( Channel channel, SslContext sslContext )
{ {
SSLEngine sslEngine = sslContext.newEngine( channel.alloc() ); SSLEngine sslEngine = sslContext.newEngine( channel.alloc() );
return new SslHandler( sslEngine ); return new SslHandler( sslEngine );
Expand All @@ -128,9 +125,9 @@ public ChannelHandler nettyClientHandler( Channel channel ) throws SSLException
return nettyClientHandler( channel, nettyClientContext() ); return nettyClientHandler( channel, nettyClientContext() );
} }


ChannelHandler nettyClientHandler( Channel channel, SslContext sslContext ) throws SSLException ChannelHandler nettyClientHandler( Channel channel, SslContext sslContext )
{ {
return new OnConnectSslHandler( channel, sslContext, true, verifyHostname, tlsVersions ); return new ClientSideOnConnectSslHandler( channel, sslContext, verifyHostname, tlsVersions );
} }


public PrivateKey privateKey() public PrivateKey privateKey()
Expand Down
15 changes: 7 additions & 8 deletions integrationtests/src/test/java/org/neo4j/ssl/SecureClient.java
Expand Up @@ -27,6 +27,7 @@
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
Expand All @@ -49,7 +50,6 @@
public class SecureClient public class SecureClient
{ {
private Bootstrap bootstrap; private Bootstrap bootstrap;
private ClientInitializer clientInitializer;
private NioEventLoopGroup eventLoopGroup; private NioEventLoopGroup eventLoopGroup;
private Channel channel; private Channel channel;
private Bucket bucket = new Bucket(); private Bucket bucket = new Bucket();
Expand All @@ -62,10 +62,9 @@ public class SecureClient
public SecureClient( SslPolicy sslPolicy ) throws SSLException public SecureClient( SslPolicy sslPolicy ) throws SSLException
{ {
eventLoopGroup = new NioEventLoopGroup(); eventLoopGroup = new NioEventLoopGroup();
clientInitializer = new ClientInitializer( sslPolicy, bucket );
bootstrap = new Bootstrap().group( eventLoopGroup ) bootstrap = new Bootstrap().group( eventLoopGroup )
.channel( NioSocketChannel.class ) .channel( NioSocketChannel.class )
.handler( clientInitializer ); .handler( new ClientInitializer( sslPolicy, bucket ) );
} }


public Future<Channel> sslHandshakeFuture() public Future<Channel> sslHandshakeFuture()
Expand Down Expand Up @@ -132,13 +131,13 @@ static class Bucket extends SimpleChannelInboundHandler<ByteBuf>
} }


@Override @Override
protected void channelRead0( ChannelHandlerContext ctx, ByteBuf msg ) throws Exception protected void channelRead0( ChannelHandlerContext ctx, ByteBuf msg )
{ {
collectedData.writeBytes( msg ); collectedData.writeBytes( msg );
} }


@Override @Override
public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause ) throws Exception public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause )
{ {
} }
} }
Expand All @@ -157,13 +156,13 @@ public class ClientInitializer extends ChannelInitializer<SocketChannel>
} }


@Override @Override
protected void initChannel( SocketChannel channel ) throws Exception protected void initChannel( SocketChannel channel )
{ {
ChannelPipeline pipeline = channel.pipeline(); ChannelPipeline pipeline = channel.pipeline();


OnConnectSslHandler onConnectSslHandler = (OnConnectSslHandler) sslPolicy.nettyClientHandler( channel, sslContext ); ChannelHandler clientOnConnectSslHandler = sslPolicy.nettyClientHandler( channel, sslContext );


pipeline.addLast( onConnectSslHandler ); pipeline.addLast( clientOnConnectSslHandler );
pipeline.addLast( new ChannelInboundHandlerAdapter() pipeline.addLast( new ChannelInboundHandlerAdapter()
{ {
@Override @Override
Expand Down

0 comments on commit 6428e05

Please sign in to comment.