Skip to content

Commit

Permalink
Close Netty channel on Bolt handshake failure
Browse files Browse the repository at this point in the history
Bolt protocol executes a handshake before establishing new session. New worker
is started for each session after protocol negotiation and handshake complete.
Creation of a worker can potentially fail with a runtime exception. In this
case corresponding Netty channel would just remain open and client (driver)
would hang waiting for a response.

This commit makes code close channels even when handshake fails.
  • Loading branch information
lutovich committed Jan 3, 2017
1 parent 86648ee commit 9b0c8d9
Show file tree
Hide file tree
Showing 6 changed files with 428 additions and 40 deletions.
Expand Up @@ -155,12 +155,10 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies )

Authentication authentication = authentication( dependencies.authManager() );

BoltFactory boltConnectionManagerFactory = life.add(
new LifecycleManagedBoltFactory( api, dependencies.usageData(), logService, dependencies.txBridge(),
authentication, dependencies.sessionTracker() ) );
ThreadedWorkerFactory threadedSessions = new ThreadedWorkerFactory( boltConnectionManagerFactory, scheduler, logService );
WorkerFactory workerFactory = new MonitoredWorkerFactory( dependencies.monitors(), threadedSessions,
Clocks.systemClock() );
BoltFactory boltFactory = life.add( new LifecycleManagedBoltFactory( api, dependencies.usageData(),
logService, dependencies.txBridge(), authentication, dependencies.sessionTracker() ) );

WorkerFactory workerFactory = createWorkerFactory( boltFactory, scheduler, dependencies, logService );

List<ProtocolInitializer> connectors = boltConnectors( config ).stream()
.map( ( connConfig ) -> {
Expand Down Expand Up @@ -216,6 +214,13 @@ public Lifecycle newInstance( KernelContext context, Dependencies dependencies )
return life;
}

protected WorkerFactory createWorkerFactory( BoltFactory boltFactory, JobScheduler scheduler,
Dependencies dependencies, LogService logService )
{
WorkerFactory threadedWorkerFactory = new ThreadedWorkerFactory( boltFactory, scheduler, logService );
return new MonitoredWorkerFactory( dependencies.monitors(), threadedWorkerFactory, Clocks.systemClock() );
}

private SslContext createSslContext( Config config, Log log, AdvertisedSocketAddress address )
{
try
Expand Down
Expand Up @@ -23,14 +23,15 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Map;
import java.util.function.BiFunction;

import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;

import static io.netty.buffer.Unpooled.wrappedBuffer;

/**
Expand Down Expand Up @@ -76,29 +77,37 @@ public void channelRead( ChannelHandlerContext ctx, Object msg ) throws Exceptio
@Override
public void channelInactive( ChannelHandlerContext ctx ) throws Exception
{
close();
close( ctx );
}

@Override
public void handlerRemoved( ChannelHandlerContext ctx ) throws Exception
{
close();
close( ctx );
}

@Override
public void exceptionCaught( ChannelHandlerContext ctx, Throwable cause ) throws Exception
{
log.error( "Fatal error occurred when handling a client connection: " + cause.getMessage(), cause );
close();
close( ctx );
}

private void close()
private void close( ChannelHandlerContext ctx )
{
if(protocol != null)
if ( protocol != null )
{
// handshake was successful and protocol was initialized, so it needs to be closed now
// channel will be closed as part of the protocol's close procedure
protocol.close();
protocol = null;
}
else
{
// handshake did not happen or failed, protocol was not initialized, so we need to close the channel
// channel will be closed as part of the context's close procedure
ctx.close();
}
}

private void chooseProtocolVersion( ChannelHandlerContext ctx, ByteBuf buffer ) throws Exception
Expand Down
Expand Up @@ -24,39 +24,39 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import org.junit.Test;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;

import org.neo4j.bolt.transport.BoltProtocol;
import org.neo4j.bolt.transport.SocketTransportHandler;
import org.neo4j.bolt.v1.runtime.SynchronousBoltWorker;
import org.neo4j.bolt.v1.runtime.BoltStateMachine;
import org.neo4j.bolt.v1.runtime.SynchronousBoltWorker;
import org.neo4j.bolt.v1.transport.BoltProtocolV1;
import org.neo4j.kernel.impl.logging.NullLogService;
import org.neo4j.logging.AssertableLogProvider;
import org.neo4j.logging.NullLogProvider;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.*;
import static org.junit.Assert.assertSame;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.bolt.transport.SocketTransportHandler.ProtocolChooser;
import static org.neo4j.logging.AssertableLogProvider.inLog;

public class SocketTransportHandlerTest
{
@Test
public void shouldCloseSessionOnChannelClose() throws Throwable
public void shouldCloseProtocolOnChannelInactive() throws Throwable
{
// Given
BoltStateMachine machine = mock(BoltStateMachine.class);
Channel ch = mock( Channel.class );
ChannelHandlerContext ctx = mock( ChannelHandlerContext.class );
when(ctx.channel()).thenReturn( ch );

when( ch.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );
when( ctx.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );
BoltStateMachine machine = mock( BoltStateMachine.class );
ChannelHandlerContext ctx = channelHandlerContextMock();

SocketTransportHandler handler = new SocketTransportHandler( protocolChooser( machine ), NullLogProvider.getInstance() );
SocketTransportHandler handler = newSocketTransportHandler( protocolChooser( machine ) );

// And Given a session has been established
handler.channelRead( ctx, handshake() );
Expand All @@ -69,17 +69,58 @@ public void shouldCloseSessionOnChannelClose() throws Throwable
}

@Test
public void logsAndClosesConnectionOnUnexpectedExceptions() throws Throwable
public void shouldCloseContextWhenProtocolNotInitializedOnChannelInactive() throws Throwable
{
// Given
ChannelHandlerContext context = mock( ChannelHandlerContext.class );
SocketTransportHandler handler = newSocketTransportHandler( mock( ProtocolChooser.class ) );

// When
handler.channelInactive( context );

// Then
verify( context ).close();
}

@Test
public void shouldCloseProtocolOnHandlerRemoved() throws Throwable
{
// Given
BoltStateMachine machine = mock( BoltStateMachine.class );
ChannelHandlerContext ctx = channelHandlerContextMock();

SocketTransportHandler handler = newSocketTransportHandler( protocolChooser( machine ) );

// And Given a session has been established
handler.channelRead( ctx, handshake() );

// When
handler.handlerRemoved( ctx );

// Then
verify( machine ).close();
}

@Test
public void shouldCloseContextWhenProtocolNotInitializedOnHandlerRemoved() throws Throwable
{
// Given
BoltStateMachine machine = mock(BoltStateMachine.class);
Channel ch = mock( Channel.class );
ChannelHandlerContext ctx = mock( ChannelHandlerContext.class );
when(ctx.channel()).thenReturn( ch );
ChannelHandlerContext context = mock( ChannelHandlerContext.class );
SocketTransportHandler handler = newSocketTransportHandler( mock( ProtocolChooser.class ) );

when( ch.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );
when( ctx.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );
// When
handler.handlerRemoved( context );

// Then
verify( context ).close();
}

@Test
public void logsAndClosesProtocolOnUnexpectedExceptions() throws Throwable
{
// Given
BoltStateMachine machine = mock( BoltStateMachine.class );
ChannelHandlerContext ctx = channelHandlerContextMock();
AssertableLogProvider logging = new AssertableLogProvider();

SocketTransportHandler handler = new SocketTransportHandler( protocolChooser( machine ), logging );
Expand All @@ -94,17 +135,72 @@ public void logsAndClosesConnectionOnUnexpectedExceptions() throws Throwable
// Then
verify( machine ).close();
logging.assertExactly( inLog( SocketTransportHandler.class )
.error( equalTo("Fatal error occurred when handling a client connection: Oh no!"), is(cause) ) );
.error( equalTo( "Fatal error occurred when handling a client connection: Oh no!" ), is( cause ) ) );
}

@Test
public void logsAndClosesContextWhenProtocolNotInitializedOnUnexpectedExceptions() throws Throwable
{
// Given
ChannelHandlerContext context = mock( ChannelHandlerContext.class );
AssertableLogProvider logging = new AssertableLogProvider();
SocketTransportHandler handler = new SocketTransportHandler( mock( ProtocolChooser.class ), logging );

// When
Throwable cause = new Throwable( "Oh no!" );
handler.exceptionCaught( context, cause );

// Then
verify( context ).close();
logging.assertExactly( inLog( SocketTransportHandler.class )
.error( equalTo( "Fatal error occurred when handling a client connection: Oh no!" ),
is( cause ) ) );
}

@Test
public void shouldInitializeProtocolOnFirstMessage() throws Exception
{
BoltStateMachine machine = mock( BoltStateMachine.class );
ProtocolChooser chooser = protocolChooser( machine );
ChannelHandlerContext context = channelHandlerContextMock();

SocketTransportHandler handler = new SocketTransportHandler( chooser, NullLogProvider.getInstance() );

handler.channelRead( context, handshake() );
BoltProtocol protocol1 = chooser.chosenProtocol();

handler.channelRead( context, handshake() );
BoltProtocol protocol2 = chooser.chosenProtocol();

assertSame( protocol1, protocol2 );
}

private static SocketTransportHandler newSocketTransportHandler( ProtocolChooser protocolChooser )
{
return new SocketTransportHandler( protocolChooser, NullLogProvider.getInstance() );
}

private static ChannelHandlerContext channelHandlerContextMock()
{
Channel channel = mock( Channel.class );
ChannelHandlerContext context = mock( ChannelHandlerContext.class );
when( context.channel() ).thenReturn( channel );

when( channel.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );
when( context.alloc() ).thenReturn( UnpooledByteBufAllocator.DEFAULT );

return context;
}

private SocketTransportHandler.ProtocolChooser protocolChooser( final BoltStateMachine machine )
private ProtocolChooser protocolChooser( final BoltStateMachine machine )
{
Map<Long, BiFunction<Channel, Boolean, BoltProtocol>> availableVersions = new HashMap<>();
Map<Long,BiFunction<Channel,Boolean,BoltProtocol>> availableVersions = new HashMap<>();
availableVersions.put( (long) BoltProtocolV1.VERSION,
( channel, isSecure ) -> new BoltProtocolV1( new SynchronousBoltWorker( machine ), channel, NullLogService.getInstance() )
( channel, isSecure ) -> new BoltProtocolV1( new SynchronousBoltWorker( machine ), channel,
NullLogService.getInstance() )
);

return new SocketTransportHandler.ProtocolChooser( availableVersions, false, true );
return new ProtocolChooser( availableVersions, false, true );
}

private ByteBuf handshake()
Expand Down
4 changes: 4 additions & 0 deletions integrationtests/pom.xml
Expand Up @@ -171,6 +171,10 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>
Expand Down

0 comments on commit 9b0c8d9

Please sign in to comment.