diff --git a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/catchup/CatchUpClient.java b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/catchup/CatchUpClient.java index 4e517bae3d971..04364bed89576 100644 --- a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/catchup/CatchUpClient.java +++ b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/catchup/CatchUpClient.java @@ -143,8 +143,7 @@ void send( CatchUpRequest request ) throws ConnectException throw new ConnectException( "Channel is not connected" ); } nettyChannel.write( request.messageType() ); - nettyChannel.closeFuture().addListener( (ChannelFutureListener) future -> handler.onClose() ); - nettyChannel.writeAndFlush( request ).addListener( ChannelFutureListener.CLOSE_ON_FAILURE ); + nettyChannel.writeAndFlush( request ); } Optional millisSinceLastResponse() @@ -163,6 +162,8 @@ public void connect() throws Exception { ChannelFuture channelFuture = bootstrap.connect( destination.socketAddress() ); nettyChannel = channelFuture.sync().channel(); + nettyChannel.closeFuture().addListener( (ChannelFutureListener) future -> handler.onClose() ); + } @Override diff --git a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/catchup/CatchUpClientIT.java b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/catchup/CatchUpClientIT.java index 889588e6ce7c8..09372eaebba95 100644 --- a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/catchup/CatchUpClientIT.java +++ b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/catchup/CatchUpClientIT.java @@ -23,9 +23,8 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; -import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.SocketChannel; -import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.MessageToByteEncoder; import org.junit.After; import org.junit.Before; @@ -33,7 +32,9 @@ import java.nio.channels.ClosedChannelException; import java.time.Clock; +import java.util.List; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; import org.neo4j.causalclustering.catchup.storecopy.GetStoreIdRequest; import org.neo4j.causalclustering.net.Server; @@ -45,6 +46,7 @@ import org.neo4j.ports.allocation.PortAuthority; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class CatchUpClientIT @@ -72,9 +74,10 @@ public void shouldCloseHandlerIfChannelIsClosedInClient() throws LifecycleExcept String hostname = "localhost"; int port = PortAuthority.allocatePort(); ListenSocketAddress listenSocketAddress = new ListenSocketAddress( hostname, port ); + AtomicBoolean wasClosedByClient = new AtomicBoolean( false ); Server emptyServer = catchupServer( listenSocketAddress ); - CatchUpClient closingClient = closingChannelCatchupClient(); + CatchUpClient closingClient = closingChannelCatchupClient( wasClosedByClient ); lifeSupport.add( emptyServer ); lifeSupport.add( closingClient ); @@ -85,6 +88,7 @@ public void shouldCloseHandlerIfChannelIsClosedInClient() throws LifecycleExcept // then assertClosedChannelException( hostname, port, closingClient ); + assertTrue( wasClosedByClient.get() ); } @Test @@ -94,9 +98,10 @@ public void shouldCloseHandlerIfChannelIsClosedOnServer() String hostname = "localhost"; int port = PortAuthority.allocatePort(); ListenSocketAddress listenSocketAddress = new ListenSocketAddress( hostname, port ); + AtomicBoolean wasClosedByServer = new AtomicBoolean( false ); - Server closingChannelServer = closingChannelCatchupServer( listenSocketAddress ); - CatchUpClient emptyClient = catchupClient(); + Server closingChannelServer = closingChannelCatchupServer( listenSocketAddress, wasClosedByServer ); + CatchUpClient emptyClient = emptyClient(); lifeSupport.add( closingChannelServer ); lifeSupport.add( emptyClient ); @@ -107,6 +112,19 @@ public void shouldCloseHandlerIfChannelIsClosedOnServer() // then assertClosedChannelException( hostname, port, emptyClient ); + assertTrue( wasClosedByServer.get() ); + } + + private CatchUpClient emptyClient() + { + return catchupClient( new MessageToByteEncoder() + { + @Override + protected void encode( ChannelHandlerContext channelHandlerContext, GetStoreIdRequest getStoreIdRequest, ByteBuf byteBuf ) + { + byteBuf.writeByte( (byte) 1 ); + } + } ); } private void assertClosedChannelException( String hostname, int port, CatchUpClient closingClient ) @@ -130,25 +148,27 @@ private CatchUpResponseAdaptor neverCompletingAdaptor() return new CatchUpResponseAdaptor<>(); } - private CatchUpClient closingChannelCatchupClient() + private CatchUpClient closingChannelCatchupClient( AtomicBoolean wasClosedByClient ) { return catchupClient( new MessageToByteEncoder() { @Override protected void encode( ChannelHandlerContext ctx, Object msg, ByteBuf out ) { + wasClosedByClient.set( true ); ctx.channel().close(); } } ); } - private Server closingChannelCatchupServer( ListenSocketAddress listenSocketAddress ) + private Server closingChannelCatchupServer( ListenSocketAddress listenSocketAddress, AtomicBoolean wasClosedByServer ) { - return catchupServer( listenSocketAddress, new SimpleChannelInboundHandler() + return catchupServer( listenSocketAddress, new ByteToMessageDecoder() { @Override - protected void channelRead0( ChannelHandlerContext ctx, NioSocketChannel msg ) + protected void decode( ChannelHandlerContext ctx, ByteBuf byteBuf, List list ) { + wasClosedByServer.set( true ); ctx.channel().close(); } } );