Skip to content

Commit

Permalink
Extracted factory for bolt protocol handlers
Browse files Browse the repository at this point in the history
Previously mapping from protocol version to a function that created
protocol handler was kept in a map. This map was also passed around
as-is. It was not pretty because map type was rather large and it
was mutable at any point.

This commit extracts a dedicated factory that is able to create
protocol handlers for the given version and network channel. Code
looks cleaner with this factory.
  • Loading branch information
lutovich committed Jan 9, 2018
1 parent cdd5884 commit 5b87ea0
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 107 deletions.
Expand Up @@ -23,15 +23,15 @@
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.time.Clock;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.neo4j.bolt.logging.BoltMessageLogging;
import org.neo4j.bolt.security.auth.Authentication;
import org.neo4j.bolt.security.auth.BasicAuthentication;
import org.neo4j.bolt.transport.BoltMessagingProtocolHandler;
import org.neo4j.bolt.transport.BoltProtocolHandlerFactory;
import org.neo4j.bolt.transport.DefaultBoltProtocolHandlerFactory;
import org.neo4j.bolt.transport.Netty4LoggerFactory;
import org.neo4j.bolt.transport.NettyServer;
import org.neo4j.bolt.transport.NettyServer.ProtocolInitializer;
Expand All @@ -42,7 +42,6 @@
import org.neo4j.bolt.v1.runtime.MonitoredWorkerFactory;
import org.neo4j.bolt.v1.runtime.WorkerFactory;
import org.neo4j.bolt.v1.runtime.concurrent.ThreadedWorkerFactory;
import org.neo4j.bolt.v1.transport.BoltMessagingProtocolV1Handler;
import org.neo4j.configuration.Description;
import org.neo4j.configuration.LoadableConfig;
import org.neo4j.graphdb.GraphDatabaseService;
Expand Down Expand Up @@ -150,6 +149,7 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies )
ConnectorPortRegister connectionRegister = dependencies.connectionRegister();

TransportThrottleGroup throttleGroup = new TransportThrottleGroup( config );
BoltProtocolHandlerFactory handlerFactory = createHandlerFactory( workerFactory, throttleGroup, logService );

Map<BoltConnector, ProtocolInitializer> connectors = config.enabledBoltConnectors().stream()
.collect( Collectors.toMap( Function.identity(), connConfig ->
Expand Down Expand Up @@ -186,13 +186,11 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies )
break;
}

final Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolHandlers =
getProtocolHandlers( logService, workerFactory, throttleGroup );
return new SocketTransport( listenAddress, sslCtx, requireEncryption, logService.getInternalLogProvider(),
boltLogging, throttleGroup, protocolHandlers );
return new SocketTransport( listenAddress, sslCtx, requireEncryption,
logService.getInternalLogProvider(), boltLogging, throttleGroup, handlerFactory );
} ) );

if ( connectors.size() > 0 && !config.get( GraphDatabaseSettings.disconnected ) )
if ( !connectors.isEmpty() && !config.get( GraphDatabaseSettings.disconnected ) )
{
life.add( new NettyServer( scheduler.threadFactory( boltNetworkIO ), connectors, connectionRegister ) );
log.info( "Bolt Server extension loaded." );
Expand Down Expand Up @@ -230,20 +228,14 @@ private SslContext createSslContext( SslPolicyLoader sslPolicyFactory, Config co
}
}

private Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> getProtocolHandlers(
LogService logging, WorkerFactory workerFactory, TransportThrottleGroup throttleGroup )
private Authentication authentication( AuthManager authManager, UserManagerSupplier userManagerSupplier )
{
Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolHandlers = new HashMap<>();
protocolHandlers.put(
(long) BoltMessagingProtocolV1Handler.VERSION,
boltChannel ->
new BoltMessagingProtocolV1Handler( boltChannel, workerFactory.newWorker( boltChannel ), throttleGroup, logging )
);
return protocolHandlers;
return new BasicAuthentication( authManager, userManagerSupplier );
}

private Authentication authentication( AuthManager authManager, UserManagerSupplier userManagerSupplier )
private static BoltProtocolHandlerFactory createHandlerFactory( WorkerFactory workerFactory,
TransportThrottleGroup throttleGroup, LogService logService )
{
return new BasicAuthentication( authManager, userManagerSupplier );
return new DefaultBoltProtocolHandlerFactory( workerFactory, throttleGroup, logService );
}
}
Expand Up @@ -23,8 +23,6 @@

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.function.Function;

import org.neo4j.bolt.BoltChannel;

Expand All @@ -43,22 +41,22 @@ public class BoltHandshakeProtocolHandler
{
public static final int BOLT_MAGIC_PREAMBLE = 0x6060B017;

private final Map<Long,Function<BoltChannel, BoltMessagingProtocolHandler>> protocolHandlers;
private final BoltProtocolHandlerFactory handlerFactory;
private final boolean encryptionRequired;
private final boolean isEncrypted;
private final ByteBuffer handshakeBuffer = ByteBuffer.allocate( 5 * 4 ).order( ByteOrder.BIG_ENDIAN );

private BoltMessagingProtocolHandler protocol;

/**
* @param protocolHandlers version -> protocol mapping
* @param handlerFactory the factory to create protocol for specific version
* @param encryptionRequired whether or not the server allows only encrypted connections
* @param isEncrypted whether of not this connection is encrypted
*/
public BoltHandshakeProtocolHandler( Map<Long,Function<BoltChannel, BoltMessagingProtocolHandler>> protocolHandlers,
public BoltHandshakeProtocolHandler( BoltProtocolHandlerFactory handlerFactory,
boolean encryptionRequired, boolean isEncrypted )
{
this.protocolHandlers = protocolHandlers;
this.handlerFactory = handlerFactory;
this.encryptionRequired = encryptionRequired;
this.isEncrypted = isEncrypted;
}
Expand Down Expand Up @@ -97,9 +95,10 @@ else if ( handshakeBuffer.remaining() > buffer.readableBytes() )
for ( int i = 0; i < 4; i++ )
{
long suggestion = handshakeBuffer.getInt() & 0xFFFFFFFFL;
if ( protocolHandlers.containsKey( suggestion ) )

protocol = handlerFactory.create( suggestion, boltChannel );
if ( protocol != null )
{
protocol = protocolHandlers.get( suggestion ).apply( boltChannel );
boltChannel.log().serverEvent( "HANDSHAKE", () -> format( "0x%02X", protocol.version() ) );
return HandshakeOutcome.PROTOCOL_CHOSEN;
}
Expand Down
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2002-2018 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.bolt.transport;

import org.neo4j.bolt.BoltChannel;

/**
* Represents a component that instantiates Bolt protocol handlers.
*
* @see BoltMessagingProtocolHandler
*/
@FunctionalInterface
public interface BoltProtocolHandlerFactory
{
/**
* Instantiate a handler for Bolt protocol with the specified version. Return {@code null} when handler for the
* given version can't be instantiated.
*
* @param protocolVersion the version as negishiated by the initial handshake.
* @param channel the channel representing network connection from the client.
* @return new protocol handler when given protocol version is known and valid, {@code null} otherwise.
*/
BoltMessagingProtocolHandler create( long protocolVersion, BoltChannel channel );
}
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2002-2018 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.bolt.transport;

import org.neo4j.bolt.BoltChannel;
import org.neo4j.bolt.v1.runtime.BoltWorker;
import org.neo4j.bolt.v1.runtime.WorkerFactory;
import org.neo4j.bolt.v1.transport.BoltMessagingProtocolV1Handler;
import org.neo4j.kernel.impl.logging.LogService;

public class DefaultBoltProtocolHandlerFactory implements BoltProtocolHandlerFactory
{
private final WorkerFactory workerFactory;
private final TransportThrottleGroup throttleGroup;
private final LogService logService;

public DefaultBoltProtocolHandlerFactory( WorkerFactory workerFactory, TransportThrottleGroup throttleGroup,
LogService logService )
{
this.workerFactory = workerFactory;
this.throttleGroup = throttleGroup;
this.logService = logService;
}

@Override
public BoltMessagingProtocolHandler create( long protocolVersion, BoltChannel channel )
{
if ( protocolVersion == BoltMessagingProtocolV1Handler.VERSION )
{
BoltWorker worker = workerFactory.newWorker( channel );
return new BoltMessagingProtocolV1Handler( channel, worker, throttleGroup, logService );
}
else
{
return null;
}
}
}
Expand Up @@ -24,10 +24,6 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslContext;

import java.util.Map;
import java.util.function.Function;

import org.neo4j.bolt.BoltChannel;
import org.neo4j.bolt.logging.BoltMessageLogging;
import org.neo4j.helpers.ListenSocketAddress;
import org.neo4j.logging.LogProvider;
Expand All @@ -43,20 +39,20 @@ public class SocketTransport implements NettyServer.ProtocolInitializer
private final LogProvider logging;
private final BoltMessageLogging boltLogging;
private final TransportThrottleGroup throttleGroup;
private final Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolVersions;
private final BoltProtocolHandlerFactory handlerFactory;

public SocketTransport( ListenSocketAddress address, SslContext sslCtx, boolean encryptionRequired,
LogProvider logging, BoltMessageLogging boltLogging,
TransportThrottleGroup throttleGroup,
Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolVersions )
BoltProtocolHandlerFactory handlerFactory )
{
this.address = address;
this.sslCtx = sslCtx;
this.encryptionRequired = encryptionRequired;
this.logging = logging;
this.boltLogging = boltLogging;
this.throttleGroup = throttleGroup;
this.protocolVersions = protocolVersions;
this.handlerFactory = handlerFactory;
}

@Override
Expand All @@ -75,9 +71,10 @@ public void initChannel( SocketChannel ch ) throws Exception
// add a close listener that will uninstall throttles
ch.closeFuture().addListener( future -> throttleGroup.uninstall( ch ) );

ch.pipeline().addLast(
new TransportSelectionHandler( sslCtx, encryptionRequired, false, logging, protocolVersions,
boltLogging ) );
TransportSelectionHandler transportSelectionHandler = new TransportSelectionHandler( sslCtx,
encryptionRequired, false, logging, handlerFactory, boltLogging );

ch.pipeline().addLast( transportSelectionHandler );
}
};
}
Expand Down
Expand Up @@ -31,10 +31,7 @@
import io.netty.handler.ssl.SslHandler;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

import org.neo4j.bolt.BoltChannel;
import org.neo4j.bolt.logging.BoltMessageLogging;
import org.neo4j.logging.LogProvider;

Expand All @@ -51,18 +48,18 @@ public class TransportSelectionHandler extends ByteToMessageDecoder
private final boolean isEncrypted;
private final LogProvider logging;
private final BoltMessageLogging boltLogging;
private final Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolVersions;
private final BoltProtocolHandlerFactory handlerFactory;

TransportSelectionHandler( SslContext sslCtx, boolean encryptionRequired, boolean isEncrypted, LogProvider logging,
Map<Long, Function<BoltChannel, BoltMessagingProtocolHandler>> protocolVersions,
BoltProtocolHandlerFactory handlerFactory,
BoltMessageLogging boltLogging )
{
this.sslCtx = sslCtx;
this.encryptionRequired = encryptionRequired;
this.isEncrypted = isEncrypted;
this.logging = logging;
this.boltLogging = boltLogging;
this.protocolVersions = protocolVersions;
this.handlerFactory = handlerFactory;
}

@Override
Expand Down Expand Up @@ -121,7 +118,7 @@ private void enableSsl( ChannelHandlerContext ctx )
ChannelPipeline p = ctx.pipeline();
p.addLast( sslCtx.newHandler( ctx.alloc() ) );
p.addLast( new TransportSelectionHandler( null, encryptionRequired, true, logging,
protocolVersions, boltLogging ) );
handlerFactory, boltLogging ) );
p.remove( this );
}

Expand All @@ -147,7 +144,7 @@ private void switchToWebsocket( ChannelHandlerContext ctx )

private SocketTransportHandler newSocketTransportHandler()
{
BoltHandshakeProtocolHandler protocolHandler = new BoltHandshakeProtocolHandler( protocolVersions,
BoltHandshakeProtocolHandler protocolHandler = new BoltHandshakeProtocolHandler( handlerFactory,
encryptionRequired, isEncrypted );
return new SocketTransportHandler( protocolHandler, logging, boltLogging );
}
Expand Down

0 comments on commit 5b87ea0

Please sign in to comment.