Skip to content

Commit

Permalink
Change order. Message header is sent first.
Browse files Browse the repository at this point in the history
  • Loading branch information
RagnarW authored and martinfurmanski committed Jun 11, 2018
1 parent ecccdaf commit 4b33f67
Show file tree
Hide file tree
Showing 15 changed files with 282 additions and 76 deletions.
Expand Up @@ -54,7 +54,7 @@ class BoltStateMachineSPI implements BoltStateMachine.SPI
this.connectionTracker = connectionTracker;
this.authentication = authentication;
this.transactionSpi = transactionStateMachineSPI;
this.version = "Neo4j/" + "3.5.0";
this.version = "Neo4j/" + Version.getNeo4jVersion();
}

@Override
Expand Down
Expand Up @@ -30,8 +30,8 @@
import java.util.List;
import java.util.stream.Collectors;

import org.neo4j.causalclustering.messaging.marshalling.v2.decoding.ContentTypeDispatcher;
import org.neo4j.causalclustering.messaging.marshalling.v2.ContentTypeProtocol;
import org.neo4j.causalclustering.messaging.marshalling.v2.decoding.ContentTypeDispatcher;
import org.neo4j.causalclustering.messaging.marshalling.v2.decoding.DecodingDispatcher;
import org.neo4j.causalclustering.messaging.marshalling.v2.decoding.RaftMessageComposer;
import org.neo4j.causalclustering.messaging.marshalling.v2.decoding.ReplicatedContentDecoder;
Expand Down Expand Up @@ -76,12 +76,14 @@ public void install( Channel channel ) throws Exception
{

ContentTypeProtocol contentTypeProtocol = new ContentTypeProtocol();
DecodingDispatcher decodingDispatcher = new DecodingDispatcher( contentTypeProtocol, logProvider );
pipelineBuilderFactory
.server( channel, log )
.modify( modifiers )
.addFraming()
.onClose( decodingDispatcher::close )
.add( "raft_content_type_dispatcher", new ContentTypeDispatcher( contentTypeProtocol ) )
.add( "raft_component_decoder", new DecodingDispatcher( contentTypeProtocol, logProvider ) )
.add( "raft_component_decoder", decodingDispatcher )
.add( "raft_content_decoder", new ReplicatedContentDecoder( contentTypeProtocol ) )
.add( "raft_message_composer", new RaftMessageComposer( Clock.systemUTC() ) )
.add( "raft_handler", raftMessageHandler )
Expand Down
Expand Up @@ -24,7 +24,7 @@

public enum ContentType
{
MessageType( (byte) 0 ),
ContentType( (byte) 0 ),
ReplicatedContent( (byte) 1 ),
RaftLogEntries( (byte) 2 ),
Message( (byte) 3 );
Expand Down
Expand Up @@ -28,6 +28,6 @@ public class ContentTypeProtocol extends Protocol<ContentType>
{
public ContentTypeProtocol()
{
super( ContentType.MessageType );
super( ContentType.ContentType );
}
}
Expand Up @@ -43,7 +43,7 @@ public ContentTypeDispatcher( Protocol<ContentType> contentTypeProtocol )
@Override
protected void decode( ChannelHandlerContext ctx, ByteBuf in, List<Object> out )
{
if ( contentTypeProtocol.isExpecting( ContentType.MessageType ) )
if ( contentTypeProtocol.isExpecting( ContentType.ContentType ) )
{
byte b = in.readByte();
ContentType contentType = getContentType( b );
Expand Down
Expand Up @@ -33,12 +33,14 @@
import org.neo4j.causalclustering.messaging.marshalling.v2.ContentType;
import org.neo4j.logging.LogProvider;

public class DecodingDispatcher extends RequestDecoderDispatcher<ContentType>
public class DecodingDispatcher extends RequestDecoderDispatcher<ContentType> implements AutoCloseable
{
private final ReplicatedContentChunkDecoder decoder;

public DecodingDispatcher( Protocol<ContentType> protocol, LogProvider logProvider )
{
super( protocol, logProvider );
register( ContentType.MessageType, new ByteToMessageDecoder()
register( ContentType.ContentType, new ByteToMessageDecoder()
{
@Override
protected void decode( ChannelHandlerContext ctx, ByteBuf in, List<Object> out )
Expand All @@ -49,8 +51,15 @@ protected void decode( ChannelHandlerContext ctx, ByteBuf in, List<Object> out )
}
}
} );
register( ContentType.RaftLogEntries, new RaftLogEntryTermDecoder( protocol ) );
register( ContentType.ReplicatedContent, new ReplicatedContentChunkDecoder( ) );
register( ContentType.RaftLogEntries, new RaftLogEntryTermsDecoder( protocol ) );
decoder = new ReplicatedContentChunkDecoder();
register( ContentType.ReplicatedContent, decoder );
register( ContentType.Message, new RaftMessageDecoder( protocol ) );
}

@Override
public void close()
{
decoder.close();
}
}
Expand Up @@ -31,33 +31,38 @@
import org.neo4j.causalclustering.catchup.Protocol;
import org.neo4j.causalclustering.messaging.marshalling.v2.ContentType;

class RaftLogEntryTermDecoder extends ByteToMessageDecoder
class RaftLogEntryTermsDecoder extends ByteToMessageDecoder
{
private final Protocol<ContentType> protocol;

RaftLogEntryTermDecoder( Protocol<ContentType> protocol )
RaftLogEntryTermsDecoder( Protocol<ContentType> protocol )
{
this.protocol = protocol;
}

@Override
protected void decode( ChannelHandlerContext ctx, ByteBuf in, List<Object> out )
{
long l = in.readLong();
out.add( new RaftLogEntryTerm( l ) );
protocol.expect( ContentType.MessageType );
int size = in.readInt();
long[] terms = new long[size];
for ( int i = 0; i < size; i++ )
{
terms[i] = in.readLong();
}
out.add( new RaftLogEntryTerms( terms ) );
protocol.expect( ContentType.ContentType );
}

class RaftLogEntryTerm
class RaftLogEntryTerms
{
private final long term;
private final long[] term;

RaftLogEntryTerm( long term )
RaftLogEntryTerms( long[] term )
{
this.term = term;
}

public long term()
public long[] terms()
{
return term;
}
Expand Down
Expand Up @@ -26,20 +26,19 @@
import io.netty.handler.codec.MessageToMessageDecoder;

import java.time.Clock;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Supplier;

import org.neo4j.causalclustering.core.consensus.RaftMessages;
import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry;
import org.neo4j.causalclustering.core.replication.ReplicatedContent;

public class RaftMessageComposer extends MessageToMessageDecoder<Object>
{
private final Queue<ReplicatedContent> replicatedContents = new LinkedBlockingQueue<>();
private final RaftLogEntries raftLogEntries = new RaftLogEntries();
private final Queue<Long> raftLogEntries = new LinkedBlockingQueue<>();
private RaftMessageDecoder.RaftMessageCreator messageCreator;
private final Clock clock;

public RaftMessageComposer( Clock clock )
Expand All @@ -54,38 +53,60 @@ protected void decode( ChannelHandlerContext ctx, Object msg, List<Object> out )
{
replicatedContents.add( (ReplicatedContent) msg );
}
else if ( msg instanceof RaftLogEntryTermDecoder.RaftLogEntryTerm )
else if ( msg instanceof RaftLogEntryTermsDecoder.RaftLogEntryTerms )
{
long term = ((RaftLogEntryTermDecoder.RaftLogEntryTerm) msg).term();
raftLogEntries.add( new RaftLogEntry( term, replicatedContents.poll() ) );
for ( long term : ((RaftLogEntryTermsDecoder.RaftLogEntryTerms) msg).terms() )
{
raftLogEntries.add( term );
}
}
else if ( msg instanceof RaftMessageDecoder.RaftMessageCreator )
{
RaftMessageDecoder.RaftMessageCreator messageCreator = (RaftMessageDecoder.RaftMessageCreator) msg;
out.add( RaftMessages.ReceivedInstantClusterIdAwareMessage.of( clock.instant(), messageCreator.clusterId(),
messageCreator.result().apply( raftLogEntries, replicatedContents::poll ) ) );
if ( messageCreator != null )
{
throw new IllegalStateException( "Received raft message header. Pipeline already contains message header waiting to build." );
}
messageCreator = (RaftMessageDecoder.RaftMessageCreator) msg;
}
else
{
throw new IllegalStateException( "Unexpected object in the pipeline: " + msg );
}
if ( messageCreator != null )
{
RaftMessages.ClusterIdAwareMessage clusterIdAwareMessage = messageCreator.maybeCompose( clock, raftLogEntries, replicatedContents );
if ( clusterIdAwareMessage != null )
{
clear( clusterIdAwareMessage.toString() );
out.add( clusterIdAwareMessage );
}
}
}

private class RaftLogEntries implements Supplier<RaftLogEntry[]>
private void clear( String messageDescription )
{
private final ArrayList<RaftLogEntry> raftLogEntries = new ArrayList<>();

void add( RaftLogEntry raftLogEntry )
messageCreator = null;
if ( !replicatedContents.isEmpty() || !raftLogEntries.isEmpty() )
{
raftLogEntries.add( raftLogEntry );
throw new IllegalStateException( String.format(
"Message [%s] was composed without using all resources in the pipeline. " +
"Pipeline still contains Replicated contents[%s] and RaftLogEntries [%s]",
messageDescription, stringify( replicatedContents ), stringify( raftLogEntries ) ) );
}
}

@Override
public RaftLogEntry[] get()
private String stringify( Iterable<?> objects )
{
StringBuilder stringBuilder = new StringBuilder();
Iterator<?> iterator = objects.iterator();
while ( iterator.hasNext() )
{
RaftLogEntry[] array = this.raftLogEntries.toArray( RaftLogEntry.empty );
raftLogEntries.clear();
return array;
stringBuilder.append( iterator.next() );
if ( iterator.hasNext() )
{
stringBuilder.append( ", " );
}
}
return stringBuilder.toString();
}
}
Expand Up @@ -27,9 +27,10 @@
import io.netty.handler.codec.ByteToMessageDecoder;

import java.io.IOException;
import java.time.Clock;
import java.util.List;
import java.util.Queue;
import java.util.function.BiFunction;
import java.util.function.Supplier;

import org.neo4j.causalclustering.catchup.Protocol;
import org.neo4j.causalclustering.core.consensus.RaftMessages;
Expand Down Expand Up @@ -73,7 +74,7 @@ public void decode( ChannelHandlerContext ctx, ByteBuf buffer, List<Object> list
RaftMessages.Type messageType = values[messageTypeWire];

MemberId from = retrieveMember( channel );
BiFunction<Supplier<RaftLogEntry[]>,Supplier<ReplicatedContent>,RaftMessages.BaseRaftMessage> result;
BiFunction<Queue<Long>,Queue<ReplicatedContent>,RaftMessages.BaseRaftMessage> result;

if ( messageType.equals( VOTE_REQUEST ) )
{
Expand Down Expand Up @@ -116,8 +117,29 @@ else if ( messageType.equals( APPEND_ENTRIES_REQUEST ) )
long prevLogIndex = channel.getLong();
long prevLogTerm = channel.getLong();
long leaderCommit = channel.getLong();

result = ( rle, rc ) -> new RaftMessages.AppendEntries.Request( from, term, prevLogIndex, prevLogTerm, rle.get(), leaderCommit );
int raftLogEntries = channel.getInt();

result = ( rle, rc ) ->
{
if ( rle.size() < raftLogEntries || rc.size() < raftLogEntries )
{
return null;
}
else
{
RaftLogEntry[] entries = new RaftLogEntry[raftLogEntries];
for ( int i = 0; i < raftLogEntries; i++ )
{
Long poll = rle.poll();
if ( poll == null )
{
throw new IllegalArgumentException( "Term cannot be null" );
}
entries[i] = new RaftLogEntry( poll, rc.poll() );
}
return new RaftMessages.AppendEntries.Request( from, term, prevLogIndex, prevLogTerm, entries, leaderCommit );
}
};
}
else if ( messageType.equals( APPEND_ENTRIES_RESPONSE ) )
{
Expand All @@ -130,7 +152,17 @@ else if ( messageType.equals( APPEND_ENTRIES_RESPONSE ) )
}
else if ( messageType.equals( NEW_ENTRY_REQUEST ) )
{
result = ( rle, rc ) -> new RaftMessages.NewEntry.Request( from, rc.get() );
result = ( rle, rc ) ->
{
if ( rc.isEmpty() )
{
return null;
}
else
{
return new RaftMessages.NewEntry.Request( from, rc.poll() );
}
};
}
else if ( messageType.equals( HEARTBEAT ) )
{
Expand All @@ -157,33 +189,36 @@ else if ( messageType.equals( LOG_COMPACTION_INFO ) )
}

list.add( new RaftMessageCreator( result, clusterId ) );
protocol.expect( ContentType.MessageType );
protocol.expect( ContentType.ContentType );
}

private BiFunction<Supplier<RaftLogEntry[]>,Supplier<ReplicatedContent>,RaftMessages.BaseRaftMessage> noContent( RaftMessages.BaseRaftMessage message )
private BiFunction<Queue<Long>,Queue<ReplicatedContent>,RaftMessages.BaseRaftMessage> noContent( RaftMessages.BaseRaftMessage message )
{
return ( rle, rc ) -> message;
}

class RaftMessageCreator
static class RaftMessageCreator
{
private final BiFunction<Supplier<RaftLogEntry[]>,Supplier<ReplicatedContent>,RaftMessages.BaseRaftMessage> result;
private final BiFunction<Queue<Long>,Queue<ReplicatedContent>,RaftMessages.BaseRaftMessage> result;
private final ClusterId clusterId;

RaftMessageCreator( BiFunction<Supplier<RaftLogEntry[]>,Supplier<ReplicatedContent>,RaftMessages.BaseRaftMessage> result, ClusterId clusterId )
RaftMessageCreator( BiFunction<Queue<Long>,Queue<ReplicatedContent>,RaftMessages.BaseRaftMessage> result, ClusterId clusterId )
{
this.result = result;
this.clusterId = clusterId;
}

public ClusterId clusterId()
{
return clusterId;
}

public BiFunction<Supplier<RaftLogEntry[]>,Supplier<ReplicatedContent>,RaftMessages.BaseRaftMessage> result()
RaftMessages.ClusterIdAwareMessage maybeCompose( Clock clock, Queue<Long> logEntryTerms, Queue<ReplicatedContent> replicatedContents )
{
return result;
RaftMessages.BaseRaftMessage apply = result.apply( logEntryTerms, replicatedContents );
if ( apply != null )
{
return RaftMessages.ReceivedInstantClusterIdAwareMessage.of( clock.instant(), clusterId, apply );
}
else
{
return null;
}
}
}

Expand Down

0 comments on commit 4b33f67

Please sign in to comment.