Skip to content

Commit

Permalink
Refactor Netty hostname verification to clarify that it is client side
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewkerr9000 committed Sep 14, 2018
1 parent a080a20 commit 6428e05
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
Expand Up @@ -23,7 +23,10 @@
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;

public class HostnameVerificationEngineModification implements Function<SSLEngine,SSLEngine>
/**
* Client side modifier for SSLEngine to mandate hostname verification
*/
public class ClientSideHostnameVerificationEngineModification implements Function<SSLEngine,SSLEngine>
{
/**
* Apply modifications to engine to enable hostname verification (client side only)
Expand Down
Expand Up @@ -37,22 +37,22 @@
import java.util.function.Function;
import javax.net.ssl.SSLEngine;

public class OnConnectSslHandler extends ChannelDuplexHandler
public class ClientSideOnConnectSslHandler extends ChannelDuplexHandler
{
private final ChannelPipeline pipeline;
private final SslContext sslContext;
private final Collection<Function<SSLEngine,SSLEngine>> engineModifications;

OnConnectSslHandler( Channel channel, SslContext sslContext, boolean isClient, boolean verifyHostname, String[] tlsVersions )
ClientSideOnConnectSslHandler( Channel channel, SslContext sslContext, boolean verifyHostname, String[] tlsVersions )
{
this.pipeline = channel.pipeline();
this.sslContext = sslContext;

this.engineModifications = new ArrayList<>();
engineModifications.add( new EssentialEngineModifications( tlsVersions, isClient ) );
engineModifications.add( new EssentialEngineModifications( tlsVersions, true ) );
if ( verifyHostname )
{
engineModifications.add( new HostnameVerificationEngineModification() );
engineModifications.add( new ClientSideHostnameVerificationEngineModification() );
}
}

Expand Down
9 changes: 3 additions & 6 deletions community/ssl/src/main/java/org/neo4j/ssl/SslPolicy.java
Expand Up @@ -35,7 +35,6 @@
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import javax.xml.bind.DatatypeConverter;

import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;
Expand All @@ -55,7 +54,6 @@ public class SslPolicy
private final SslProvider sslProvider;

private final boolean verifyHostname;
private final LogProvider logProvider;
private final Log log;

public SslPolicy( PrivateKey privateKey, X509Certificate[] keyCertChain, List<String> tlsVersions, List<String> ciphers, ClientAuth clientAuth,
Expand All @@ -69,7 +67,6 @@ public SslPolicy( PrivateKey privateKey, X509Certificate[] keyCertChain, List<St
this.trustManagerFactory = trustManagerFactory;
this.sslProvider = sslProvider;
this.verifyHostname = verifyHostname;
this.logProvider = logProvider;
this.log = logProvider.getLog( SslPolicy.class );
}

Expand Down Expand Up @@ -116,7 +113,7 @@ public ChannelHandler nettyServerHandler( Channel channel ) throws SSLException
return nettyServerHandler( channel, nettyServerContext() );
}

private ChannelHandler nettyServerHandler( Channel channel, SslContext sslContext ) throws SSLException
private ChannelHandler nettyServerHandler( Channel channel, SslContext sslContext )
{
SSLEngine sslEngine = sslContext.newEngine( channel.alloc() );
return new SslHandler( sslEngine );
Expand All @@ -128,9 +125,9 @@ public ChannelHandler nettyClientHandler( Channel channel ) throws SSLException
return nettyClientHandler( channel, nettyClientContext() );
}

ChannelHandler nettyClientHandler( Channel channel, SslContext sslContext ) throws SSLException
ChannelHandler nettyClientHandler( Channel channel, SslContext sslContext )
{
return new OnConnectSslHandler( channel, sslContext, true, verifyHostname, tlsVersions );
return new ClientSideOnConnectSslHandler( channel, sslContext, verifyHostname, tlsVersions );
}

public PrivateKey privateKey()
Expand Down
15 changes: 7 additions & 8 deletions integrationtests/src/test/java/org/neo4j/ssl/SecureClient.java
Expand Up @@ -27,6 +27,7 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
Expand All @@ -49,7 +50,6 @@
public class SecureClient
{
private Bootstrap bootstrap;
private ClientInitializer clientInitializer;
private NioEventLoopGroup eventLoopGroup;
private Channel channel;
private Bucket bucket = new Bucket();
Expand All @@ -62,10 +62,9 @@ public class SecureClient
public SecureClient( SslPolicy sslPolicy ) throws SSLException
{
eventLoopGroup = new NioEventLoopGroup();
clientInitializer = new ClientInitializer( sslPolicy, bucket );
bootstrap = new Bootstrap().group( eventLoopGroup )
.channel( NioSocketChannel.class )
.handler( clientInitializer );
.handler( new ClientInitializer( sslPolicy, bucket ) );
}

public Future<Channel> sslHandshakeFuture()
Expand Down Expand Up @@ -132,13 +131,13 @@ static class Bucket extends SimpleChannelInboundHandler<ByteBuf>
}

@Override
protected void channelRead0( ChannelHandlerContext ctx, ByteBuf msg ) throws Exception
protected void channelRead0( ChannelHandlerContext ctx, ByteBuf msg )
{
collectedData.writeBytes( msg );
}

@Override
public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause ) throws Exception
public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause )
{
}
}
Expand All @@ -157,13 +156,13 @@ public class ClientInitializer extends ChannelInitializer<SocketChannel>
}

@Override
protected void initChannel( SocketChannel channel ) throws Exception
protected void initChannel( SocketChannel channel )
{
ChannelPipeline pipeline = channel.pipeline();

OnConnectSslHandler onConnectSslHandler = (OnConnectSslHandler) sslPolicy.nettyClientHandler( channel, sslContext );
ChannelHandler clientOnConnectSslHandler = sslPolicy.nettyClientHandler( channel, sslContext );

pipeline.addLast( onConnectSslHandler );
pipeline.addLast( clientOnConnectSslHandler );
pipeline.addLast( new ChannelInboundHandlerAdapter()
{
@Override
Expand Down

0 comments on commit 6428e05

Please sign in to comment.