diff --git a/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java b/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java index 5f6097068d5f..dfca29fba51d 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java @@ -155,12 +155,10 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies ) Authentication authentication = authentication( dependencies.authManager() ); - BoltFactory boltConnectionManagerFactory = life.add( - new LifecycleManagedBoltFactory( api, dependencies.usageData(), logService, dependencies.txBridge(), - authentication, dependencies.sessionTracker() ) ); - ThreadedWorkerFactory threadedSessions = new ThreadedWorkerFactory( boltConnectionManagerFactory, scheduler, logService ); - WorkerFactory workerFactory = new MonitoredWorkerFactory( dependencies.monitors(), threadedSessions, - Clocks.systemClock() ); + BoltFactory boltFactory = life.add( new LifecycleManagedBoltFactory( api, dependencies.usageData(), + logService, dependencies.txBridge(), authentication, dependencies.sessionTracker() ) ); + + WorkerFactory workerFactory = createWorkerFactory( boltFactory, scheduler, dependencies, logService ); List connectors = boltConnectors( config ).stream() .map( ( connConfig ) -> { @@ -216,6 +214,13 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies ) return life; } + protected WorkerFactory createWorkerFactory( BoltFactory boltFactory, JobScheduler scheduler, + Dependencies dependencies, LogService logService ) + { + WorkerFactory threadedWorkerFactory = new ThreadedWorkerFactory( boltFactory, scheduler, logService ); + return new MonitoredWorkerFactory( dependencies.monitors(), threadedWorkerFactory, Clocks.systemClock() ); + } + private SslContext createSslContext( Config config, Log log, AdvertisedSocketAddress address ) { try diff --git a/community/bolt/src/main/java/org/neo4j/bolt/transport/SocketTransportHandler.java b/community/bolt/src/main/java/org/neo4j/bolt/transport/SocketTransportHandler.java index cc74b7c0c3e4..1a44abe1e350 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/transport/SocketTransportHandler.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/transport/SocketTransportHandler.java @@ -23,14 +23,15 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; -import org.neo4j.logging.Log; -import org.neo4j.logging.LogProvider; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Map; import java.util.function.BiFunction; +import org.neo4j.logging.Log; +import org.neo4j.logging.LogProvider; + import static io.netty.buffer.Unpooled.wrappedBuffer; /** @@ -76,29 +77,37 @@ public void channelRead( ChannelHandlerContext ctx, Object msg ) throws Exceptio @Override public void channelInactive( ChannelHandlerContext ctx ) throws Exception { - close(); + close( ctx ); } @Override public void handlerRemoved( ChannelHandlerContext ctx ) throws Exception { - close(); + close( ctx ); } @Override public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause ) throws Exception { log.error( "Fatal error occurred when handling a client connection: " + cause.getMessage(), cause ); - close(); + close( ctx ); } - private void close() + private void close( ChannelHandlerContext ctx ) { - if(protocol != null) + if ( protocol != null ) { + // handshake was successful and protocol was initialized, so it needs to be closed now + // channel will be closed as part of the protocol's close procedure protocol.close(); protocol = null; } + else + { + // handshake did not happen or failed, protocol was not initialized, so we need to close the channel + // channel will be closed as part of the context's close procedure + ctx.close(); + } } private void chooseProtocolVersion( ChannelHandlerContext ctx, ByteBuf buffer ) throws Exception diff --git a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java index 822d54ab2c45..b80ca4bbbbf3 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java @@ -24,39 +24,39 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + import org.neo4j.bolt.transport.BoltProtocol; import org.neo4j.bolt.transport.SocketTransportHandler; -import org.neo4j.bolt.v1.runtime.SynchronousBoltWorker; import org.neo4j.bolt.v1.runtime.BoltStateMachine; +import org.neo4j.bolt.v1.runtime.SynchronousBoltWorker; import org.neo4j.bolt.v1.transport.BoltProtocolV1; import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.logging.AssertableLogProvider; import org.neo4j.logging.NullLogProvider; -import java.util.HashMap; -import java.util.Map; -import java.util.function.BiFunction; - import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -import static org.mockito.Mockito.*; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.bolt.transport.SocketTransportHandler.ProtocolChooser; import static org.neo4j.logging.AssertableLogProvider.inLog; public class SocketTransportHandlerTest { @Test - public void shouldCloseSessionOnChannelClose() throws Throwable + public void shouldCloseProtocolOnChannelInactive() throws Throwable { // Given - BoltStateMachine machine = mock(BoltStateMachine.class); - Channel ch = mock( Channel.class ); - ChannelHandlerContext ctx = mock( ChannelHandlerContext.class ); - when(ctx.channel()).thenReturn( ch ); - - when( ch.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); - when( ctx.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); + BoltStateMachine machine = mock( BoltStateMachine.class ); + ChannelHandlerContext ctx = channelHandlerContextMock(); - SocketTransportHandler handler = new SocketTransportHandler( protocolChooser( machine ), NullLogProvider.getInstance() ); + SocketTransportHandler handler = newSocketTransportHandler( protocolChooser( machine ) ); // And Given a session has been established handler.channelRead( ctx, handshake() ); @@ -69,17 +69,58 @@ public void shouldCloseSessionOnChannelClose() throws Throwable } @Test - public void logsAndClosesConnectionOnUnexpectedExceptions() throws Throwable + public void shouldCloseContextWhenProtocolNotInitializedOnChannelInactive() throws Throwable + { + // Given + ChannelHandlerContext context = mock( ChannelHandlerContext.class ); + SocketTransportHandler handler = newSocketTransportHandler( mock( ProtocolChooser.class ) ); + + // When + handler.channelInactive( context ); + + // Then + verify( context ).close(); + } + + @Test + public void shouldCloseProtocolOnHandlerRemoved() throws Throwable + { + // Given + BoltStateMachine machine = mock( BoltStateMachine.class ); + ChannelHandlerContext ctx = channelHandlerContextMock(); + + SocketTransportHandler handler = newSocketTransportHandler( protocolChooser( machine ) ); + + // And Given a session has been established + handler.channelRead( ctx, handshake() ); + + // When + handler.handlerRemoved( ctx ); + + // Then + verify( machine ).close(); + } + + @Test + public void shouldCloseContextWhenProtocolNotInitializedOnHandlerRemoved() throws Throwable { // Given - BoltStateMachine machine = mock(BoltStateMachine.class); - Channel ch = mock( Channel.class ); - ChannelHandlerContext ctx = mock( ChannelHandlerContext.class ); - when(ctx.channel()).thenReturn( ch ); + ChannelHandlerContext context = mock( ChannelHandlerContext.class ); + SocketTransportHandler handler = newSocketTransportHandler( mock( ProtocolChooser.class ) ); - when( ch.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); - when( ctx.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); + // When + handler.handlerRemoved( context ); + // Then + verify( context ).close(); + } + + @Test + public void logsAndClosesProtocolOnUnexpectedExceptions() throws Throwable + { + // Given + BoltStateMachine machine = mock( BoltStateMachine.class ); + ChannelHandlerContext ctx = channelHandlerContextMock(); AssertableLogProvider logging = new AssertableLogProvider(); SocketTransportHandler handler = new SocketTransportHandler( protocolChooser( machine ), logging ); @@ -94,17 +135,72 @@ public void logsAndClosesConnectionOnUnexpectedExceptions() throws Throwable // Then verify( machine ).close(); logging.assertExactly( inLog( SocketTransportHandler.class ) - .error( equalTo("Fatal error occurred when handling a client connection: Oh no!"), is(cause) ) ); + .error( equalTo( "Fatal error occurred when handling a client connection: Oh no!" ), is( cause ) ) ); + } + + @Test + public void logsAndClosesContextWhenProtocolNotInitializedOnUnexpectedExceptions() throws Throwable + { + // Given + ChannelHandlerContext context = mock( ChannelHandlerContext.class ); + AssertableLogProvider logging = new AssertableLogProvider(); + SocketTransportHandler handler = new SocketTransportHandler( mock( ProtocolChooser.class ), logging ); + + // When + Throwable cause = new Throwable( "Oh no!" ); + handler.exceptionCaught( context, cause ); + + // Then + verify( context ).close(); + logging.assertExactly( inLog( SocketTransportHandler.class ) + .error( equalTo( "Fatal error occurred when handling a client connection: Oh no!" ), + is( cause ) ) ); + } + + @Test + public void shouldInitializeProtocolOnFirstMessage() throws Exception + { + BoltStateMachine machine = mock( BoltStateMachine.class ); + ProtocolChooser chooser = protocolChooser( machine ); + ChannelHandlerContext context = channelHandlerContextMock(); + + SocketTransportHandler handler = new SocketTransportHandler( chooser, NullLogProvider.getInstance() ); + + handler.channelRead( context, handshake() ); + BoltProtocol protocol1 = chooser.chosenProtocol(); + + handler.channelRead( context, handshake() ); + BoltProtocol protocol2 = chooser.chosenProtocol(); + + assertSame( protocol1, protocol2 ); + } + + private static SocketTransportHandler newSocketTransportHandler( ProtocolChooser protocolChooser ) + { + return new SocketTransportHandler( protocolChooser, NullLogProvider.getInstance() ); + } + + private static ChannelHandlerContext channelHandlerContextMock() + { + Channel channel = mock( Channel.class ); + ChannelHandlerContext context = mock( ChannelHandlerContext.class ); + when( context.channel() ).thenReturn( channel ); + + when( channel.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); + when( context.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT ); + + return context; } - private SocketTransportHandler.ProtocolChooser protocolChooser( final BoltStateMachine machine ) + private ProtocolChooser protocolChooser( final BoltStateMachine machine ) { - Map> availableVersions = new HashMap<>(); + Map> availableVersions = new HashMap<>(); availableVersions.put( (long) BoltProtocolV1.VERSION, - ( channel, isSecure ) -> new BoltProtocolV1( new SynchronousBoltWorker( machine ), channel, NullLogService.getInstance() ) + ( channel, isSecure ) -> new BoltProtocolV1( new SynchronousBoltWorker( machine ), channel, + NullLogService.getInstance() ) ); - return new SocketTransportHandler.ProtocolChooser( availableVersions, false, true ); + return new ProtocolChooser( availableVersions, false, true ); } private ByteBuf handshake() diff --git a/integrationtests/pom.xml b/integrationtests/pom.xml index ec85a22d523b..874ac2aae362 100644 --- a/integrationtests/pom.xml +++ b/integrationtests/pom.xml @@ -171,6 +171,10 @@ + + org.mockito + mockito-core + org.neo4j.driver neo4j-java-driver diff --git a/integrationtests/src/test/java/org/neo4j/bolt/BoltFailuresIT.java b/integrationtests/src/test/java/org/neo4j/bolt/BoltFailuresIT.java new file mode 100644 index 000000000000..130f09c77452 --- /dev/null +++ b/integrationtests/src/test/java/org/neo4j/bolt/BoltFailuresIT.java @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package org.neo4j.bolt; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; + +import org.neo4j.bolt.v1.runtime.BoltFactory; +import org.neo4j.bolt.v1.runtime.WorkerFactory; +import org.neo4j.driver.v1.Driver; +import org.neo4j.driver.v1.GraphDatabase; +import org.neo4j.driver.v1.exceptions.ConnectionFailureException; +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.factory.GraphDatabaseSettings; +import org.neo4j.kernel.impl.logging.LogService; +import org.neo4j.kernel.impl.util.JobScheduler; +import org.neo4j.test.rule.TestDirectory; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.neo4j.graphdb.factory.GraphDatabaseSettings.Connector.ConnectorType.BOLT; +import static org.neo4j.graphdb.factory.GraphDatabaseSettings.boltConnector; +import static org.neo4j.kernel.configuration.Settings.FALSE; +import static org.neo4j.kernel.configuration.Settings.TRUE; + +public class BoltFailuresIT +{ + @Rule + public final TestDirectory dir = TestDirectory.testDirectory(); + + private GraphDatabaseService db; + + @After + public void shutdownDb() + { + if ( db != null ) + { + db.shutdown(); + } + } + + @Test( timeout = 20_000 ) + public void throwsWhenSessionCreationFails() + { + WorkerFactory workerFactory = mock( WorkerFactory.class ); + when( workerFactory.newWorker( anyString(), any() ) ).thenThrow( new IllegalStateException( "Oh!" ) ); + + db = newDbFactory( new BoltKernelExtensionWithWorkerFactory( workerFactory ) ); + + try ( Driver driver = GraphDatabase.driver( "bolt://localhost" ) ) + { + driver.session(); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( ConnectionFailureException.class ) ); + } + } + + private GraphDatabaseService newDbFactory( BoltKernelExtension boltKernelExtension ) + { + return new GraphDatabaseFactoryWithCustomBoltKernelExtension( boltKernelExtension ) + .newEmbeddedDatabaseBuilder( dir.graphDbDir() ) + .setConfig( boltConnector( "0" ).type, BOLT.name() ) + .setConfig( boltConnector( "0" ).enabled, TRUE ) + .setConfig( GraphDatabaseSettings.auth_enabled, FALSE ) + .newGraphDatabase(); + } + + private static class BoltKernelExtensionWithWorkerFactory extends BoltKernelExtension + { + final WorkerFactory workerFactory; + + BoltKernelExtensionWithWorkerFactory( WorkerFactory workerFactory ) + { + this.workerFactory = workerFactory; + } + + @Override + protected WorkerFactory createWorkerFactory( BoltFactory boltFactory, JobScheduler scheduler, + Dependencies dependencies, LogService logService ) + { + return workerFactory; + } + } +} diff --git a/integrationtests/src/test/java/org/neo4j/bolt/GraphDatabaseFactoryWithCustomBoltKernelExtension.java b/integrationtests/src/test/java/org/neo4j/bolt/GraphDatabaseFactoryWithCustomBoltKernelExtension.java new file mode 100644 index 000000000000..fc97c4bb557f --- /dev/null +++ b/integrationtests/src/test/java/org/neo4j/bolt/GraphDatabaseFactoryWithCustomBoltKernelExtension.java @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ +package org.neo4j.bolt; + +import java.io.File; +import java.util.Map; + +import org.neo4j.graphdb.GraphDatabaseService; +import org.neo4j.graphdb.factory.GraphDatabaseFactory; +import org.neo4j.graphdb.security.URLAccessRule; +import org.neo4j.helpers.collection.Iterables; +import org.neo4j.kernel.extension.KernelExtensionFactory; +import org.neo4j.kernel.extension.KernelExtensions; +import org.neo4j.kernel.impl.enterprise.EnterpriseEditionModule; +import org.neo4j.kernel.impl.factory.DatabaseInfo; +import org.neo4j.kernel.impl.factory.GraphDatabaseFacade; +import org.neo4j.kernel.impl.factory.GraphDatabaseFacadeFactory; +import org.neo4j.kernel.impl.factory.PlatformModule; +import org.neo4j.kernel.impl.query.QueryEngineProvider; +import org.neo4j.kernel.impl.spi.KernelContext; +import org.neo4j.kernel.lifecycle.Lifecycle; +import org.neo4j.kernel.monitoring.Monitors; +import org.neo4j.logging.LogProvider; + +import static java.util.stream.Collectors.toList; +import static org.neo4j.kernel.impl.factory.GraphDatabaseFacadeFactory.Dependencies; + +public class GraphDatabaseFactoryWithCustomBoltKernelExtension extends GraphDatabaseFactory +{ + private final BoltKernelExtension customExtension; + + public GraphDatabaseFactoryWithCustomBoltKernelExtension( BoltKernelExtension customExtension ) + { + this.customExtension = customExtension; + } + + @Override + protected GraphDatabaseService newDatabase( File storeDir, Map config, Dependencies dependencies ) + { + GraphDatabaseFacadeFactory factory = new CustomBoltKernelExtensionFacadeFactory( customExtension ); + return factory.newFacade( storeDir, config, dependencies ); + } + + private static class CustomBoltKernelExtensionFacadeFactory extends GraphDatabaseFacadeFactory + { + final BoltKernelExtension customExtension; + + CustomBoltKernelExtensionFacadeFactory( BoltKernelExtension customExtension ) + { + super( DatabaseInfo.ENTERPRISE, EnterpriseEditionModule::new ); + this.customExtension = customExtension; + } + + @Override + protected PlatformModule createPlatform( File storeDir, Map params, + Dependencies dependencies, GraphDatabaseFacade graphDatabaseFacade ) + { + Dependencies newDependencies = new CustomBoltKernelExtensionDependencies( customExtension, dependencies ); + return new PlatformModule( storeDir, params, databaseInfo, newDependencies, graphDatabaseFacade ); + } + } + + private static class CustomBoltKernelExtensionDependencies implements Dependencies + { + final BoltKernelExtension customExtension; + final Dependencies delegate; + + CustomBoltKernelExtensionDependencies( BoltKernelExtension customExtension, Dependencies delegate ) + { + this.customExtension = customExtension; + this.delegate = delegate; + } + + @Override + public Monitors monitors() + { + return delegate.monitors(); + } + + @Override + public LogProvider userLogProvider() + { + return delegate.userLogProvider(); + } + + @Override + public Iterable> settingsClasses() + { + return delegate.settingsClasses(); + } + + @Override + public Iterable> kernelExtensions() + { + return Iterables.stream( delegate.kernelExtensions() ) + .map( this::replaceBoltKernelExtensionFactory ) + .collect( toList() ); + } + + @Override + public Map urlAccessRules() + { + return delegate.urlAccessRules(); + } + + @Override + public Iterable executionEngines() + { + return delegate.executionEngines(); + } + + KernelExtensionFactory replaceBoltKernelExtensionFactory( KernelExtensionFactory factory ) + { + if ( factory instanceof BoltKernelExtension ) + { + return new CustomBoltKernelExtension( customExtension ); + } + return factory; + } + } + + /** + * Each kernel extension factory is expected to extend {@link KernelExtensionFactory} and have some dependencies + * as it's type parameter. That is why we can't use given custom extension as is, it can extend a real + * {@link BoltKernelExtension}. So this wrapper delegates to the given extension and has same superclass as the + * real {@link BoltKernelExtension}. + * + * @see KernelExtensions#getKernelExtensionDependencies(KernelExtensionFactory) + */ + private static class CustomBoltKernelExtension extends KernelExtensionFactory + { + final BoltKernelExtension customExtension; + + CustomBoltKernelExtension( BoltKernelExtension customExtension ) + { + super( "custom-bolt-server" ); + this.customExtension = customExtension; + } + + @Override + public Lifecycle newInstance( KernelContext context, BoltKernelExtension.Dependencies dependencies ) + throws Throwable + { + return customExtension.newInstance( context, dependencies ); + } + } +}