Skip to content

Commit

Permalink
Fix buffer leak in Bolt output
Browse files Browse the repository at this point in the history
BoltProtocolV1 is a container for inbound and outbound Netty-based connections.
There exists one protocol instance for each neo4j client. Previously the outbound
connection, represented by the ChunkedOutput class, was passed in as a parameter
and never closed. This caused a native memory leak because connections
allocate native byte buffers to put into the channel.

This commit makes BoltProtocolV1 responsible for the ChunkedOutput lifecycle. The
output is created in the constructor and closed when the protocol instance is closed.
  • Loading branch information
lutovich committed Jul 27, 2016
1 parent c82f2e6 commit 173ac0d
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 27 deletions.
Expand Up @@ -39,15 +39,13 @@
import org.neo4j.bolt.transport.NettyServer; import org.neo4j.bolt.transport.NettyServer;
import org.neo4j.bolt.transport.NettyServer.ProtocolInitializer; import org.neo4j.bolt.transport.NettyServer.ProtocolInitializer;
import org.neo4j.bolt.transport.SocketTransport; import org.neo4j.bolt.transport.SocketTransport;
import org.neo4j.bolt.v1.messaging.Neo4jPack;
import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1;
import org.neo4j.bolt.v1.runtime.MonitoredSessions; import org.neo4j.bolt.v1.runtime.MonitoredSessions;
import org.neo4j.bolt.v1.runtime.Session;
import org.neo4j.bolt.v1.runtime.Sessions; import org.neo4j.bolt.v1.runtime.Sessions;
import org.neo4j.bolt.v1.runtime.internal.EncryptionRequiredSessions; import org.neo4j.bolt.v1.runtime.internal.EncryptionRequiredSessions;
import org.neo4j.bolt.v1.runtime.internal.StandardSessions; import org.neo4j.bolt.v1.runtime.internal.StandardSessions;
import org.neo4j.bolt.v1.runtime.internal.concurrent.ThreadedSessions; import org.neo4j.bolt.v1.runtime.internal.concurrent.ThreadedSessions;
import org.neo4j.bolt.v1.transport.BoltProtocolV1; import org.neo4j.bolt.v1.transport.BoltProtocolV1;
import org.neo4j.bolt.v1.transport.ChunkedOutput;
import org.neo4j.collection.primitive.PrimitiveLongObjectMap; import org.neo4j.collection.primitive.PrimitiveLongObjectMap;
import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.config.Configuration; import org.neo4j.graphdb.config.Configuration;
Expand Down Expand Up @@ -218,9 +216,8 @@ private PrimitiveLongObjectMap<BiFunction<Channel,Boolean,BoltProtocol>> newVers
BoltProtocolV1.VERSION, BoltProtocolV1.VERSION,
( channel, isEncrypted ) -> { ( channel, isEncrypted ) -> {
String descriptor = format( "\tclient%s\tserver%s", channel.remoteAddress(), channel.localAddress() ); String descriptor = format( "\tclient%s\tserver%s", channel.remoteAddress(), channel.localAddress() );
ChunkedOutput output = new ChunkedOutput( channel, 8192 ); Session session = sessions.newSession( descriptor, isEncrypted );
return new BoltProtocolV1( logging, sessions.newSession( descriptor, isEncrypted ), return new BoltProtocolV1( session, channel, logging );
new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output ) );
} }
); );
return availableVersions; return availableVersions;
Expand Down
Expand Up @@ -20,13 +20,15 @@
package org.neo4j.bolt.v1.transport; package org.neo4j.bolt.v1.transport;


import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;


import java.io.IOException; import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;


import org.neo4j.bolt.transport.BoltProtocol; import org.neo4j.bolt.transport.BoltProtocol;
import org.neo4j.bolt.v1.messaging.MessageFormat; import org.neo4j.bolt.v1.messaging.MessageFormat;
import org.neo4j.bolt.v1.messaging.Neo4jPack;
import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1; import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1;
import org.neo4j.bolt.v1.messaging.msgprocess.TransportBridge; import org.neo4j.bolt.v1.messaging.msgprocess.TransportBridge;
import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.runtime.Session;
Expand All @@ -42,6 +44,10 @@
public class BoltProtocolV1 implements BoltProtocol public class BoltProtocolV1 implements BoltProtocol
{ {
public static final int VERSION = 1; public static final int VERSION = 1;

private static final int DEFAULT_OUTPUT_BUFFER_SIZE = 8192;

private final ChunkedOutput chunkedOutput;
private final MessageFormat.Writer packer; private final MessageFormat.Writer packer;
private final BoltV1Dechunker dechunker; private final BoltV1Dechunker dechunker;


Expand All @@ -50,14 +56,15 @@ public class BoltProtocolV1 implements BoltProtocol
private final AtomicInteger inFlight = new AtomicInteger( 0 ); private final AtomicInteger inFlight = new AtomicInteger( 0 );
private final TransportBridge bridge; private final TransportBridge bridge;


public BoltProtocolV1( final LogService logging, Session session, PackStreamMessageFormatV1.Writer output ) public BoltProtocolV1( Session session, Channel outputChannel, LogService logging )
{ {
// TODO; this part of the Bolt server side is rather messy - notably, the MessageHandler, Session and Session.Callback interfaces all // TODO; this part of the Bolt server side is rather messy - notably, the MessageHandler, Session and Session.Callback interfaces all
// should reasonably be able to be refactored into something much less complicated. // should reasonably be able to be refactored into something much less complicated.
// Likewise the tracking of when to flush the outbound channel - if we moved that logic to ThreadedSessions, a lot of the complexity // Likewise the tracking of when to flush the outbound channel - if we moved that logic to ThreadedSessions, a lot of the complexity
// below could likely be undone. // below could likely be undone.
this.chunkedOutput = new ChunkedOutput( outputChannel, DEFAULT_OUTPUT_BUFFER_SIZE );
this.packer = new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( chunkedOutput ), chunkedOutput );
this.session = session; this.session = session;
this.packer = output;
this.bridge = new TransportBridge( logging.getInternalLog( getClass() ), session, packer, this::onMessageDone ); this.bridge = new TransportBridge( logging.getInternalLog( getClass() ), session, packer, this::onMessageDone );
this.dechunker = new BoltV1Dechunker( bridge, this::onMessageStarted ); this.dechunker = new BoltV1Dechunker( bridge, this::onMessageStarted );
} }
Expand All @@ -78,7 +85,10 @@ public void handle( ChannelHandlerContext channelContext, ByteBuf data ) throws
catch ( Throwable e ) catch ( Throwable e )
{ {
bridge.handleFatalError( Neo4jError.from( e ) ); bridge.handleFatalError( Neo4jError.from( e ) );
close();
// close input, keep output open. we still need to write error back to the client higher in the
// call stack. we are not going to read anything after the error.
closeInput();
} }
finally finally
{ {
Expand All @@ -94,11 +104,22 @@ public int version()


@Override @Override
public synchronized void close() public synchronized void close()
{
closeInput();
closeOutput();
}

private void closeInput()
{ {
dechunker.close(); dechunker.close();
session.close(); session.close();
} }


/**
 * Closes the outbound side of this protocol instance by closing the {@code chunkedOutput}
 * that the constructor created. In the code visible here it is invoked from {@code close()};
 * the fatal-error path in {@code handle(..)} only closes the input so that the error can
 * still be written back to the client.
 */
private void closeOutput()
{
chunkedOutput.close();
}

/* /*
* The methods below are used to track in-flight messages (messages the client has sent us that are waiting to be processed). We use this information * The methods below are used to track in-flight messages (messages the client has sent us that are waiting to be processed). We use this information
* to determine when to explicitly flush our output buffers - if there are no more pending messages when a message is done processing, we should flush * to determine when to explicitly flush our output buffers - if there are no more pending messages when a message is done processing, we should flush
Expand Down
Expand Up @@ -20,30 +20,38 @@
package org.neo4j.bolt.v1.transport; package org.neo4j.bolt.v1.transport;


import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import org.junit.Test; import org.junit.Test;


import java.io.IOException; import java.io.IOException;


import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1;
import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.runtime.Session;
import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.kernel.impl.logging.NullLogService;


import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.RETURNS_MOCKS;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;


public class BoltProtocolV1Test public class BoltProtocolV1Test
{ {
@Test @Test
public void shouldNotTalkToChannelDirectlyOnFatalError() throws Throwable public void shouldNotTalkToChannelDirectlyOnFatalError() throws Throwable
{ {
// Given // Given
PackStreamMessageFormatV1.Writer output = mock( PackStreamMessageFormatV1.Writer.class ); Channel outputChannel = mock( Channel.class );
ByteBufAllocator allocator = mock( ByteBufAllocator.class, RETURNS_MOCKS );
when( outputChannel.alloc() ).thenReturn( allocator );

Session session = mock( Session.class ); Session session = mock( Session.class );
BoltProtocolV1 protocol = new BoltProtocolV1( NullLogService.getInstance(), session, output ); BoltProtocolV1 protocol = new BoltProtocolV1( session, outputChannel, NullLogService.getInstance() );
verify( outputChannel ).alloc();


// And given inbound data that'll explode when the protocol tries to interpret it // And given inbound data that'll explode when the protocol tries to interpret it
ByteBuf bomb = mock(ByteBuf.class); ByteBuf bomb = mock(ByteBuf.class);
Expand All @@ -53,10 +61,28 @@ public void shouldNotTalkToChannelDirectlyOnFatalError() throws Throwable
protocol.handle( mock(ChannelHandlerContext.class), bomb ); protocol.handle( mock(ChannelHandlerContext.class), bomb );


// Then the protocol should not mess with the channel (because it runs on the IO thread, and only the worker thread should produce writes) // Then the protocol should not mess with the channel (because it runs on the IO thread, and only the worker thread should produce writes)
verifyNoMoreInteractions( output ); verifyNoMoreInteractions( outputChannel );


// But instead signal to the session that shit hit the fan. // But instead signal to the session that shit hit the fan.
verify( session ).externalError( any(), any(), any() ); verify( session ).externalError( any(), any(), any() );
verify( session ).close(); verify( session ).close();
} }
}
/**
 * Verifies that closing the protocol closes both sides: the session (input) is closed and
 * the buffer handed out by the channel's allocator (output) is released.
 */
@Test
public void closesInputAndOutput()
{
// Given a channel whose allocator hands out a mock buffer we can verify against
// NOTE(review): presumably ChunkedOutput obtains its buffer via alloc().buffer(int, int) - confirm
Channel outputChannel = mock( Channel.class );
ByteBufAllocator allocator = mock( ByteBufAllocator.class );
ByteBuf buffer = mock( ByteBuf.class );
when( outputChannel.alloc() ).thenReturn( allocator );
when( allocator.buffer( anyInt(), anyInt() ) ).thenReturn( buffer );

Session session = mock( Session.class );

// When the protocol is created on that channel and then closed
BoltProtocolV1 protocol = new BoltProtocolV1( session, outputChannel, NullLogService.getInstance() );
protocol.close();

// Then the session is closed and the output buffer is released
verify( session ).close();
verify( buffer ).release();
}
}
Expand Up @@ -36,7 +36,6 @@
import org.neo4j.bolt.v1.packstream.BufferedChannelOutput; import org.neo4j.bolt.v1.packstream.BufferedChannelOutput;
import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.runtime.Session;
import org.neo4j.bolt.v1.transport.BoltProtocolV1; import org.neo4j.bolt.v1.transport.BoltProtocolV1;
import org.neo4j.bolt.v1.transport.ChunkedOutput;
import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.kernel.impl.logging.NullLogService;
import org.neo4j.kernel.impl.util.HexPrinter; import org.neo4j.kernel.impl.util.HexPrinter;


Expand Down Expand Up @@ -116,9 +115,7 @@ private void testPermutation( byte[] unfragmented, ByteBuf[] fragments ) throws
ChannelHandlerContext ctx = mock( ChannelHandlerContext.class ); ChannelHandlerContext ctx = mock( ChannelHandlerContext.class );
when(ctx.channel()).thenReturn( ch ); when(ctx.channel()).thenReturn( ch );


ChunkedOutput output = new ChunkedOutput( ch, 8192 ); BoltProtocolV1 protocol = new BoltProtocolV1( sess, ch, NullLogService.getInstance() );
BoltProtocolV1 protocol = new BoltProtocolV1( NullLogService.getInstance(), sess,
new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output ) );


// When data arrives split up according to the current permutation // When data arrives split up according to the current permutation
for ( ByteBuf fragment : fragments ) for ( ByteBuf fragment : fragments )
Expand Down Expand Up @@ -172,4 +169,4 @@ private byte[] serialize( int chunkSize, Message... msgs ) throws IOException
} }
return Chunker.chunk( chunkSize, serialized ); return Chunker.chunk( chunkSize, serialized );
} }
} }
Expand Up @@ -29,11 +29,8 @@


import org.neo4j.bolt.transport.BoltProtocol; import org.neo4j.bolt.transport.BoltProtocol;
import org.neo4j.bolt.transport.SocketTransportHandler; import org.neo4j.bolt.transport.SocketTransportHandler;
import org.neo4j.bolt.v1.messaging.Neo4jPack;
import org.neo4j.bolt.v1.messaging.PackStreamMessageFormatV1;
import org.neo4j.bolt.v1.runtime.Session; import org.neo4j.bolt.v1.runtime.Session;
import org.neo4j.bolt.v1.transport.BoltProtocolV1; import org.neo4j.bolt.v1.transport.BoltProtocolV1;
import org.neo4j.bolt.v1.transport.ChunkedOutput;
import org.neo4j.collection.primitive.PrimitiveLongObjectMap; import org.neo4j.collection.primitive.PrimitiveLongObjectMap;
import org.neo4j.kernel.impl.logging.NullLogService; import org.neo4j.kernel.impl.logging.NullLogService;
import org.neo4j.logging.AssertableLogProvider; import org.neo4j.logging.AssertableLogProvider;
Expand Down Expand Up @@ -106,11 +103,7 @@ private SocketTransportHandler.ProtocolChooser protocolChooser( final Session se
{ {
PrimitiveLongObjectMap<BiFunction<Channel,Boolean,BoltProtocol>> availableVersions = longObjectMap(); PrimitiveLongObjectMap<BiFunction<Channel,Boolean,BoltProtocol>> availableVersions = longObjectMap();
availableVersions.put( BoltProtocolV1.VERSION, availableVersions.put( BoltProtocolV1.VERSION,
( channel, isSecure ) -> { ( channel, isSecure ) -> new BoltProtocolV1( session, channel, NullLogService.getInstance() )
ChunkedOutput output = new ChunkedOutput( channel, 8192 );
return new BoltProtocolV1( NullLogService.getInstance(), session,
new PackStreamMessageFormatV1.Writer( new Neo4jPack.Packer( output ), output ));
}
); );


return new SocketTransportHandler.ProtocolChooser( availableVersions, true ); return new SocketTransportHandler.ProtocolChooser( availableVersions, true );
Expand Down

0 comments on commit 173ac0d

Please sign in to comment.