diff --git a/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java b/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java index 7d1f033f1e844..aa03e0e06d331 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/BoltKernelExtension.java @@ -39,15 +39,13 @@ import org.neo4j.bolt.transport.NettyServer; import org.neo4j.bolt.transport.NettyServer.ProtocolInitializer; import org.neo4j.bolt.transport.SocketTransport; -import org.neo4j.bolt.v1.messaging.Neo4jPack; -import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1; import org.neo4j.bolt.v1.runtime.MonitoredSessions; +import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.runtime.Sessions; import org.neo4j.bolt.v1.runtime.internal.EncryptionRequiredSessions; import org.neo4j.bolt.v1.runtime.internal.StandardSessions; import org.neo4j.bolt.v1.runtime.internal.concurrent.ThreadedSessions; import org.neo4j.bolt.v1.transport.BoltProtocolV1; -import org.neo4j.bolt.v1.transport.ChunkedOutput; import org.neo4j.collection.primitive.PrimitiveLongObjectMap; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.config.Configuration; @@ -218,9 +216,8 @@ private PrimitiveLongObjectMap> newVers BoltProtocolV1.VERSION, ( channel, isEncrypted ) -> { String descriptor = format( "\tclient%s\tserver%s", channel.remoteAddress(), channel.localAddress() ); - ChunkedOutput output = new ChunkedOutput( channel, 8192 ); - return new BoltProtocolV1( logging, sessions.newSession( descriptor, isEncrypted ), - new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output ) ); + Session session = sessions.newSession( descriptor, isEncrypted ); + return new BoltProtocolV1( session, channel, logging ); } ); return availableVersions; diff --git a/community/bolt/src/main/java/org/neo4j/bolt/v1/transport/BoltProtocolV1.java b/community/bolt/src/main/java/org/neo4j/bolt/v1/transport/BoltProtocolV1.java index 40c4bd8635d0e..e1b5918016661 100644 --- a/community/bolt/src/main/java/org/neo4j/bolt/v1/transport/BoltProtocolV1.java +++ b/community/bolt/src/main/java/org/neo4j/bolt/v1/transport/BoltProtocolV1.java @@ -20,6 +20,7 @@ package org.neo4j.bolt.v1.transport; import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import java.io.IOException; @@ -27,6 +28,7 @@ import org.neo4j.bolt.transport.BoltProtocol; import org.neo4j.bolt.v1.messaging.MessageFormat; +import org.neo4j.bolt.v1.messaging.Neo4jPack; import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1; import org.neo4j.bolt.v1.messaging.msgprocess.TransportBridge; import org.neo4j.bolt.v1.runtime.Session; @@ -42,6 +44,10 @@ public class BoltProtocolV1 implements BoltProtocol { public static final int VERSION = 1; + + private static final int DEFAULT_OUTPUT_BUFFER_SIZE = 8192; + + private final ChunkedOutput chunkedOutput; private final MessageFormat.Writer packer; private final BoltV1Dechunker dechunker; @@ -50,14 +56,15 @@ public class BoltProtocolV1 implements BoltProtocol private final AtomicInteger inFlight = new AtomicInteger( 0 ); private final TransportBridge bridge; - public BoltProtocolV1( final LogService logging, Session session, PackStreamMessageFormatV1.Writer output ) + public BoltProtocolV1( Session session, Channel outputChannel, LogService logging ) { // TODO; this part of the Bolt server side is rather messy - notably, the MessageHandler, Session and Session.Callback interfaces all // should reasonably be able to be refactored into something much less complicated. // Likewise the tracking of when to flush the outbound channel - if we moved that logic to ThreadedSessions, a lot of the complexity // below could likely be undone. + this.chunkedOutput = new ChunkedOutput( outputChannel, DEFAULT_OUTPUT_BUFFER_SIZE ); + this.packer = new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( chunkedOutput ), chunkedOutput ); this.session = session; - this.packer = output; this.bridge = new TransportBridge( logging.getInternalLog( getClass() ), session, packer, this::onMessageDone ); this.dechunker = new BoltV1Dechunker( bridge, this::onMessageStarted ); } @@ -78,7 +85,10 @@ public void handle( ChannelHandlerContext channelContext, ByteBuf data ) throws catch ( Throwable e ) { bridge.handleFatalError( Neo4jError.from( e ) ); - close(); + + // close input, keep output open. we still need to write error back to the client higher in the + // call stack. we are not going to read anything after the error. + closeInput(); } finally { @@ -94,11 +104,22 @@ public int version() @Override public synchronized void close() + { + closeInput(); + closeOutput(); + } + + private void closeInput() { dechunker.close(); session.close(); } + private void closeOutput() + { + chunkedOutput.close(); + } + /* * Ths methods below are used to track in-flight messages (messages the client has sent us that are waiting to be processed). We use this information * to determine when to explicitly flush our output buffers - if there are no more pending messages when a message is done processing, we should flush diff --git a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/BoltProtocolV1Test.java b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/BoltProtocolV1Test.java index 82db2a40ce512..fe73a37bac9f1 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/BoltProtocolV1Test.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/BoltProtocolV1Test.java @@ -20,20 +20,24 @@ package org.neo4j.bolt.v1.transport; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import org.junit.Test; import java.io.IOException; -import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1; import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.kernel.impl.logging.NullLogService; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyInt; +import static org.mockito.Mockito.RETURNS_MOCKS; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; public class BoltProtocolV1Test { @@ -41,9 +45,13 @@ public class BoltProtocolV1Test public void shouldNotTalkToChannelDirectlyOnFatalError() throws Throwable { // Given - PackStreamMessageFormatV1.Writer output = mock( PackStreamMessageFormatV1.Writer.class ); + Channel outputChannel = mock( Channel.class ); + ByteBufAllocator allocator = mock( ByteBufAllocator.class, RETURNS_MOCKS ); + when( outputChannel.alloc() ).thenReturn( allocator ); + Session session = mock( Session.class ); - BoltProtocolV1 protocol = new BoltProtocolV1( NullLogService.getInstance(), session, output ); + BoltProtocolV1 protocol = new BoltProtocolV1( session, outputChannel, NullLogService.getInstance() ); + verify( outputChannel ).alloc(); // And given inbound data that'll explode when the protocol tries to interpret it ByteBuf bomb = mock(ByteBuf.class); @@ -53,10 +61,28 @@ public void shouldNotTalkToChannelDirectlyOnFatalError() throws Throwable protocol.handle( mock(ChannelHandlerContext.class), bomb ); // Then the protocol should not mess with the channel (because it runs on the IO thread, and only the worker thread should produce writes) - verifyNoMoreInteractions( output ); + verifyNoMoreInteractions( outputChannel ); // But instead signal to the session that shit hit the fan. verify( session ).externalError( any(), any(), any() ); verify( session ).close(); } -} \ No newline at end of file + + @Test + public void closesInputAndOutput() + { + Channel outputChannel = mock( Channel.class ); + ByteBufAllocator allocator = mock( ByteBufAllocator.class ); + ByteBuf buffer = mock( ByteBuf.class ); + when( outputChannel.alloc() ).thenReturn( allocator ); + when( allocator.buffer( anyInt(), anyInt() ) ).thenReturn( buffer ); + + Session session = mock( Session.class ); + + BoltProtocolV1 protocol = new BoltProtocolV1( session, outputChannel, NullLogService.getInstance() ); + protocol.close(); + + verify( session ).close(); + verify( buffer ).release(); + } +} diff --git a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/FragmentedMessageDeliveryTest.java b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/FragmentedMessageDeliveryTest.java index dd1fde42f0eba..a497ca18ead79 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/FragmentedMessageDeliveryTest.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/FragmentedMessageDeliveryTest.java @@ -36,7 +36,6 @@ import org.neo4j.bolt.v1.packstream.BufferedChannelOutput; import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.transport.BoltProtocolV1; -import org.neo4j.bolt.v1.transport.ChunkedOutput; import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.kernel.impl.util.HexPrinter; @@ -116,9 +115,7 @@ private void testPermutation( byte[] unfragmented, ByteBuf[] fragments ) throws ChannelHandlerContext ctx = mock( ChannelHandlerContext.class ); when(ctx.channel()).thenReturn( ch ); - ChunkedOutput output = new ChunkedOutput( ch, 8192 ); - BoltProtocolV1 protocol = new BoltProtocolV1( NullLogService.getInstance(), sess, - new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output ) ); + BoltProtocolV1 protocol = new BoltProtocolV1( sess, ch, NullLogService.getInstance() ); // When data arrives split up according to the current permutation for ( ByteBuf fragment : fragments ) @@ -172,4 +169,4 @@ private byte[] serialize( int chunkSize, Message... msgs ) throws IOException } return Chunker.chunk( chunkSize, serialized ); } -} \ No newline at end of file +} diff --git a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java index ed0a05d8f82fb..0ae6a1db15551 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/v1/transport/socket/SocketTransportHandlerTest.java @@ -29,11 +29,8 @@ import org.neo4j.bolt.transport.BoltProtocol; import org.neo4j.bolt.transport.SocketTransportHandler; -import org.neo4j.bolt.v1.messaging.Neo4jPack; -import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1; import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.transport.BoltProtocolV1; -import org.neo4j.bolt.v1.transport.ChunkedOutput; import org.neo4j.collection.primitive.PrimitiveLongObjectMap; import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.logging.AssertableLogProvider; @@ -106,11 +103,7 @@ private SocketTransportHandler.ProtocolChooser protocolChooser( final Session se { PrimitiveLongObjectMap> availableVersions = longObjectMap(); availableVersions.put( BoltProtocolV1.VERSION, - ( channel, isSecure ) -> { - ChunkedOutput output = new ChunkedOutput( channel, 8192 ); - return new BoltProtocolV1( NullLogService.getInstance(), session, - new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output )); - } + ( channel, isSecure ) -> new BoltProtocolV1( session, channel, NullLogService.getInstance() ) ); return new SocketTransportHandler.ProtocolChooser( availableVersions, true );