Skip to content

Commit

Permalink
Add dispatcher that delegates to the correct decoder in CatchupServer
Browse files Browse the repository at this point in the history
This will avoid unnecessary buffer copies.
  • Loading branch information
davidegrohmann committed Aug 16, 2016
1 parent 739e95a commit 39a10da
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 94 deletions.
Expand Up @@ -20,7 +20,9 @@
package org.neo4j.coreedge.catchup; package org.neo4j.coreedge.catchup;


import io.netty.bootstrap.ServerBootstrap; import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
Expand All @@ -29,11 +31,13 @@
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender; import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.handler.stream.ChunkedWriteHandler;


import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Supplier; import java.util.function.Supplier;


import org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage;
import org.neo4j.coreedge.catchup.storecopy.FileHeaderEncoder; import org.neo4j.coreedge.catchup.storecopy.FileHeaderEncoder;
import org.neo4j.coreedge.catchup.storecopy.GetStoreIdRequestDecoder; import org.neo4j.coreedge.catchup.storecopy.GetStoreIdRequestDecoder;
import org.neo4j.coreedge.catchup.storecopy.GetStoreIdRequestHandler; import org.neo4j.coreedge.catchup.storecopy.GetStoreIdRequestHandler;
Expand Down Expand Up @@ -103,11 +107,6 @@ public CatchupServer( LogProvider logProvider,


@Override @Override
public synchronized void start() throws Throwable public synchronized void start() throws Throwable
{
startNettyServer();
}

private void startNettyServer()
{ {
workerGroup = new NioEventLoopGroup( 0, threadFactory ); workerGroup = new NioEventLoopGroup( 0, threadFactory );


Expand Down Expand Up @@ -136,10 +135,7 @@ protected void initChannel( SocketChannel ch ) throws Exception


pipeline.addLast( new ServerMessageTypeHandler( protocol, logProvider ) ); pipeline.addLast( new ServerMessageTypeHandler( protocol, logProvider ) );


pipeline.addLast( new TxPullRequestDecoder( protocol ) ); pipeline.addLast( decoders( protocol ) );
pipeline.addLast( new GetStoreRequestDecoder( protocol ) );
pipeline.addLast( new GetStoreIdRequestDecoder( protocol ) );
pipeline.addLast( new CoreSnapshotRequestDecoder( protocol ) );


pipeline.addLast( new TxPullRequestHandler( protocol, storeIdSupplier, pipeline.addLast( new TxPullRequestHandler( protocol, storeIdSupplier,
transactionIdStoreSupplier, logicalTransactionStoreSupplier, transactionIdStoreSupplier, logicalTransactionStoreSupplier,
Expand All @@ -156,6 +152,16 @@ protected void initChannel( SocketChannel ch ) throws Exception
channel = bootstrap.bind().syncUninterruptibly().channel(); channel = bootstrap.bind().syncUninterruptibly().channel();
} }


private ChannelInboundHandler decoders( CatchupServerProtocol protocol )
{
RequestDecoderDispatcher decoderDispatcher = new RequestDecoderDispatcher( protocol, logProvider );
decoderDispatcher.register( NextMessage.TX_PULL, new TxPullRequestDecoder() );
decoderDispatcher.register( NextMessage.GET_STORE, new GetStoreRequestDecoder() );
decoderDispatcher.register( NextMessage.GET_STORE_ID, new GetStoreIdRequestDecoder() );
decoderDispatcher.register( NextMessage.GET_RAFT_STATE, new CoreSnapshotRequestDecoder() );
return decoderDispatcher;
}

@Override @Override
public synchronized void stop() throws Throwable public synchronized void stop() throws Throwable
{ {
Expand Down
Expand Up @@ -28,9 +28,9 @@ public void expect( NextMessage nextMessage )
this.nextMessage = nextMessage; this.nextMessage = nextMessage;
} }


public boolean isExpecting( NextMessage message ) NextMessage expecting()
{ {
return this.nextMessage == message; return nextMessage;
} }


public enum NextMessage public enum NextMessage
Expand Down
@@ -0,0 +1,64 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.coreedge.catchup;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;

import java.util.HashMap;
import java.util.Map;

import org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage;
import org.neo4j.logging.Log;
import org.neo4j.logging.LogProvider;

public class RequestDecoderDispatcher extends ChannelInboundHandlerAdapter
{
private final Map<NextMessage, ChannelInboundHandler> decoders = new HashMap<>();
private final CatchupServerProtocol protocol;
private final Log log;

public RequestDecoderDispatcher( CatchupServerProtocol protocol, LogProvider logProvider )
{
this.protocol = protocol;
this.log = logProvider.getLog( getClass() );
}

@Override
public void channelRead( ChannelHandlerContext ctx, Object msg ) throws Exception
{
NextMessage expecting = protocol.expecting();
ChannelInboundHandler delegate = decoders.get( expecting );
if ( delegate == null )
{
log.warn( "Unknown message %s", expecting );
return;
}

delegate.channelRead( ctx, msg );
}

public void register( NextMessage type, ChannelInboundHandler decoder )
{
assert !decoders.containsKey( type ) : "registering twice a decoder for the same type?";
decoders.put( type, decoder );
}
}
Expand Up @@ -29,12 +29,12 @@


import static org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage; import static org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage;


public class ServerMessageTypeHandler extends ChannelInboundHandlerAdapter class ServerMessageTypeHandler extends ChannelInboundHandlerAdapter
{ {
private final Log log; private final Log log;
private final CatchupServerProtocol protocol; private final CatchupServerProtocol protocol;


public ServerMessageTypeHandler( CatchupServerProtocol protocol, LogProvider logProvider ) ServerMessageTypeHandler( CatchupServerProtocol protocol, LogProvider logProvider )
{ {
this.protocol = protocol; this.protocol = protocol;
this.log = logProvider.getLog( getClass() ); this.log = logProvider.getLog( getClass() );
Expand All @@ -43,7 +43,7 @@ public ServerMessageTypeHandler( CatchupServerProtocol protocol, LogProvider log
@Override @Override
public void channelRead( ChannelHandlerContext ctx, Object msg ) throws Exception public void channelRead( ChannelHandlerContext ctx, Object msg ) throws Exception
{ {
if ( protocol.isExpecting( NextMessage.MESSAGE_TYPE ) ) if ( protocol.expecting().equals( NextMessage.MESSAGE_TYPE ) )
{ {
RequestMessageType requestMessageType = RequestMessageType.from( ((ByteBuf) msg).readByte() ); RequestMessageType requestMessageType = RequestMessageType.from( ((ByteBuf) msg).readByte() );


Expand Down
Expand Up @@ -30,23 +30,9 @@


public class GetStoreIdRequestDecoder extends MessageToMessageDecoder<ByteBuf> public class GetStoreIdRequestDecoder extends MessageToMessageDecoder<ByteBuf>
{ {
private final CatchupServerProtocol protocol;

public GetStoreIdRequestDecoder( CatchupServerProtocol protocol )
{
this.protocol = protocol;
}

@Override @Override
protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception
{ {
if ( protocol.isExpecting( CatchupServerProtocol.NextMessage.GET_STORE_ID ) ) out.add( new GetStoreIdRequest() );
{
out.add( new GetStoreIdRequest() );
}
else
{
out.add( Unpooled.copiedBuffer( msg ) );
}
} }
} }
Expand Up @@ -19,35 +19,17 @@
*/ */
package org.neo4j.coreedge.catchup.storecopy; package org.neo4j.coreedge.catchup.storecopy;


import java.util.List;

import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;


import org.neo4j.coreedge.catchup.CatchupServerProtocol; import java.util.List;


public class GetStoreRequestDecoder extends MessageToMessageDecoder<ByteBuf> public class GetStoreRequestDecoder extends MessageToMessageDecoder<ByteBuf>
{ {
private final CatchupServerProtocol protocol;

public GetStoreRequestDecoder( CatchupServerProtocol protocol )
{
this.protocol = protocol;
}

@Override @Override
protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception
{ {
if ( protocol.isExpecting( CatchupServerProtocol.NextMessage.GET_STORE ) ) out.add( new GetStoreRequest() );
{
out.add( new GetStoreRequest() );
}
else
{
out.add( Unpooled.copiedBuffer( msg ) );
}

} }
} }
Expand Up @@ -19,41 +19,23 @@
*/ */
package org.neo4j.coreedge.catchup.tx; package org.neo4j.coreedge.catchup.tx;


import java.util.List;

import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder; import io.netty.handler.codec.MessageToMessageDecoder;


import org.neo4j.coreedge.catchup.CatchupServerProtocol; import java.util.List;

import org.neo4j.coreedge.identity.StoreId; import org.neo4j.coreedge.identity.StoreId;
import org.neo4j.coreedge.messaging.NetworkReadableClosableChannelNetty4; import org.neo4j.coreedge.messaging.NetworkReadableClosableChannelNetty4;
import org.neo4j.coreedge.messaging.marsalling.storeid.StoreIdMarshal; import org.neo4j.coreedge.messaging.marsalling.storeid.StoreIdMarshal;


import static org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage.TX_PULL;

public class TxPullRequestDecoder extends MessageToMessageDecoder<ByteBuf> public class TxPullRequestDecoder extends MessageToMessageDecoder<ByteBuf>
{ {
private final CatchupServerProtocol protocol;

public TxPullRequestDecoder( CatchupServerProtocol protocol )
{
this.protocol = protocol;
}

@Override @Override
protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception
{ {
if ( protocol.isExpecting( TX_PULL ) ) long txId = msg.readLong();
{ StoreId storeId = StoreIdMarshal.unmarshal( new NetworkReadableClosableChannelNetty4( msg ) );
long txId = msg.readLong(); out.add( new TxPullRequest( txId, storeId ) );
StoreId storeId = StoreIdMarshal.unmarshal( new NetworkReadableClosableChannelNetty4( msg ) );
out.add( new TxPullRequest( txId, storeId ) );
}
else
{
out.add( Unpooled.copiedBuffer( msg ) );
}
} }
} }
Expand Up @@ -30,24 +30,9 @@


public class CoreSnapshotRequestDecoder extends MessageToMessageDecoder<ByteBuf> public class CoreSnapshotRequestDecoder extends MessageToMessageDecoder<ByteBuf>
{ {
private final CatchupServerProtocol protocol;

public CoreSnapshotRequestDecoder( CatchupServerProtocol protocol )
{
this.protocol = protocol;
}

@Override @Override
protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception protected void decode( ChannelHandlerContext ctx, ByteBuf msg, List<Object> out ) throws Exception
{ {
if ( protocol.isExpecting( CatchupServerProtocol.NextMessage.GET_RAFT_STATE ) ) out.add( new CoreSnapshotRequest() );
{
out.add( new CoreSnapshotRequest() );
}
else
{
out.add( Unpooled.copiedBuffer( msg ) );
}

} }
} }
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.coreedge.catchup;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import org.junit.Test;

import org.neo4j.logging.AssertableLogProvider;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.neo4j.coreedge.catchup.CatchupServerProtocol.NextMessage.TX_PULL;
import static org.neo4j.logging.AssertableLogProvider.inLog;

public class RequestDecoderDispatcherTest
{
private final CatchupServerProtocol protocol = new CatchupServerProtocol();
private final AssertableLogProvider logProvider = new AssertableLogProvider();

@Test
public void shouldDispatchToRegisteredDecoder() throws Exception
{
// given
RequestDecoderDispatcher dispatcher = new RequestDecoderDispatcher( protocol, logProvider );
protocol.expect( TX_PULL );
ChannelInboundHandler delegate = mock( ChannelInboundHandler.class );
dispatcher.register( TX_PULL, delegate );

ChannelHandlerContext ctx = mock( ChannelHandlerContext.class );
Object msg = new Object();

// when
dispatcher.channelRead( ctx, msg );

// then
verify( delegate ).channelRead( ctx, msg );
verifyNoMoreInteractions( delegate );
}

@Test
public void shouldLogAWarningIfThereIsNoDecoderForTheMessageType() throws Exception
{
// given
RequestDecoderDispatcher dispatcher = new RequestDecoderDispatcher( protocol, logProvider );
protocol.expect( TX_PULL );

// when
dispatcher.channelRead( mock( ChannelHandlerContext.class ), new Object() );

// then
AssertableLogProvider.LogMatcher matcher =
inLog( RequestDecoderDispatcher.class ).warn( "Unknown message %s", TX_PULL );

logProvider.assertExactly( matcher );
}
}
Expand Up @@ -35,13 +35,8 @@ public class TxPullRequestEncodeDecodeTest
@Test @Test
public void shouldEncodeAndDecodePullRequestMessage() public void shouldEncodeAndDecodePullRequestMessage()
{ {
CatchupServerProtocol protocol = new CatchupServerProtocol();
protocol.expect( NextMessage.TX_PULL );

EmbeddedChannel channel = new EmbeddedChannel( new TxPullRequestEncoder(),
new TxPullRequestDecoder( protocol ) );

// given // given
EmbeddedChannel channel = new EmbeddedChannel( new TxPullRequestEncoder(), new TxPullRequestDecoder() );
final long arbitraryId = 23; final long arbitraryId = 23;
TxPullRequest sent = new TxPullRequest( arbitraryId, new StoreId( 1, 2, 3, 4 ) ); TxPullRequest sent = new TxPullRequest( arbitraryId, new StoreId( 1, 2, 3, 4 ) );


Expand Down

0 comments on commit 39a10da

Please sign in to comment.