diff --git a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageComposer.java b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageComposer.java index aece934fa13bb..4527f846913c1 100644 --- a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageComposer.java +++ b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageComposer.java @@ -29,6 +29,7 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import java.util.Queue; import org.neo4j.causalclustering.core.consensus.RaftMessages; @@ -74,12 +75,12 @@ else if ( msg instanceof RaftMessageDecoder.ClusterIdAwareMessageComposer ) } if ( messageComposer != null ) { - RaftMessages.ClusterIdAwareMessage clusterIdAwareMessage = messageComposer.maybeCompose( clock, raftLogEntryTerms, replicatedContents ); - if ( clusterIdAwareMessage != null ) + Optional clusterIdAwareMessage = messageComposer.maybeCompose( clock, raftLogEntryTerms, replicatedContents ); + clusterIdAwareMessage.ifPresent( message -> { - clear( clusterIdAwareMessage ); - out.add( clusterIdAwareMessage ); - } + clear( message ); + out.add( message ); + } ); } } diff --git a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageDecoder.java b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageDecoder.java index 8d07a50c738bc..8d0626839b501 100644 --- a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageDecoder.java +++ b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/RaftMessageDecoder.java @@ -29,11 +29,12 @@ import java.io.IOException; import java.time.Clock; import java.util.List; +import java.util.Optional; import java.util.Queue; -import java.util.function.BiFunction; import org.neo4j.causalclustering.catchup.Protocol; import org.neo4j.causalclustering.core.consensus.RaftMessages; +import org.neo4j.causalclustering.core.consensus.RaftMessages.ReceivedInstantClusterIdAwareMessage; import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry; import org.neo4j.causalclustering.core.replication.ReplicatedContent; import org.neo4j.causalclustering.identity.ClusterId; @@ -84,14 +85,14 @@ public void decode( ChannelHandlerContext ctx, ByteBuf buffer, List list long lastLogIndex = channel.getLong(); long lastLogTerm = channel.getLong(); - composer = simpleMessageComposer( new RaftMessages.Vote.Request( from, term, candidate, lastLogIndex, lastLogTerm ) ); + composer = new SimpleMessageComposer( new RaftMessages.Vote.Request( from, term, candidate, lastLogIndex, lastLogTerm ) ); } else if ( messageType.equals( VOTE_RESPONSE ) ) { long term = channel.getLong(); boolean voteGranted = channel.get() == 1; - composer = simpleMessageComposer( new RaftMessages.Vote.Response( from, term, voteGranted ) ); + composer = new SimpleMessageComposer( new RaftMessages.Vote.Response( from, term, voteGranted ) ); } else if ( messageType.equals( PRE_VOTE_REQUEST ) ) { @@ -101,14 +102,14 @@ else if ( messageType.equals( PRE_VOTE_REQUEST ) ) long lastLogIndex = channel.getLong(); long lastLogTerm = channel.getLong(); - composer = simpleMessageComposer( new RaftMessages.PreVote.Request( from, term, candidate, lastLogIndex, lastLogTerm ) ); + composer = new SimpleMessageComposer( new RaftMessages.PreVote.Request( from, term, candidate, lastLogIndex, lastLogTerm ) ); } else if ( messageType.equals( PRE_VOTE_RESPONSE ) ) { long term = channel.getLong(); boolean voteGranted = channel.get() == 1; - composer = simpleMessageComposer( new RaftMessages.PreVote.Response( from, term, voteGranted ) ); + composer = new SimpleMessageComposer( new RaftMessages.PreVote.Response( from, term, voteGranted ) ); } else if ( messageType.equals( APPEND_ENTRIES_REQUEST ) ) { @@ -128,7 +129,7 @@ else if ( messageType.equals( APPEND_ENTRIES_RESPONSE ) ) long matchIndex = channel.getLong(); long appendIndex = channel.getLong(); - composer = simpleMessageComposer( new RaftMessages.AppendEntries.Response( from, term, success, matchIndex, appendIndex ) ); + composer = new SimpleMessageComposer( new RaftMessages.AppendEntries.Response( from, term, success, matchIndex, appendIndex ) ); } else if ( messageType.equals( NEW_ENTRY_REQUEST ) ) { @@ -140,18 +141,18 @@ else if ( messageType.equals( HEARTBEAT ) ) long commitIndexTerm = channel.getLong(); long commitIndex = channel.getLong(); - composer = simpleMessageComposer( new RaftMessages.Heartbeat( from, leaderTerm, commitIndex, commitIndexTerm ) ); + composer = new SimpleMessageComposer( new RaftMessages.Heartbeat( from, leaderTerm, commitIndex, commitIndexTerm ) ); } else if ( messageType.equals( HEARTBEAT_RESPONSE ) ) { - composer = simpleMessageComposer( new RaftMessages.HeartbeatResponse( from ) ); + composer = new SimpleMessageComposer( new RaftMessages.HeartbeatResponse( from ) ); } else if ( messageType.equals( LOG_COMPACTION_INFO ) ) { long leaderTerm = channel.getLong(); long prevIndex = channel.getLong(); - composer = simpleMessageComposer( new RaftMessages.LogCompactionInfo( from, leaderTerm, prevIndex ) ); + composer = new SimpleMessageComposer( new RaftMessages.LogCompactionInfo( from, leaderTerm, prevIndex ) ); } else { @@ -173,18 +174,10 @@ static class ClusterIdAwareMessageComposer this.clusterId = clusterId; } - RaftMessages.ClusterIdAwareMessage maybeCompose( Clock clock, Queue logEntryTerms, Queue replicatedContents ) + Optional maybeCompose( Clock clock, Queue terms, Queue contents ) { - RaftMessages.RaftMessage composedMessage = composer.apply( logEntryTerms, replicatedContents ); - - if ( composedMessage != null ) - { - return RaftMessages.ReceivedInstantClusterIdAwareMessage.of( clock.instant(), clusterId, composedMessage ); - } - else - { - return null; - } + return composer.maybeComplete( terms, contents ) + .map( m -> ReceivedInstantClusterIdAwareMessage.of( clock.instant(), clusterId, m ) ); } } @@ -194,19 +187,31 @@ private MemberId retrieveMember( ReadableChannel buffer ) throws IOException, En return memberIdMarshal.unmarshal( buffer ); } - /** - * Builds the raft message. Should return {@code null} if provided collections does not contain enough data for building the message. - */ - interface LazyComposer extends BiFunction,Queue,RaftMessages.RaftMessage> + interface LazyComposer { + /** + * Builds the complete raft message if provided collections contain enough data for building the complete message. + */ + Optional maybeComplete( Queue terms, Queue contents ); } /** - * A message without internal content components. + * A plain message without any more internal content. */ - private static LazyComposer simpleMessageComposer( RaftMessages.RaftMessage message ) + private static class SimpleMessageComposer implements LazyComposer { - return ( terms, contents ) -> message; + private final RaftMessages.RaftMessage message; + + private SimpleMessageComposer( RaftMessages.RaftMessage message ) + { + this.message = message; + } + + @Override + public Optional maybeComplete( Queue terms, Queue contents ) + { + return Optional.of( message ); + } } private static class AppendEntriesComposer implements LazyComposer @@ -229,26 +234,24 @@ private static class AppendEntriesComposer implements LazyComposer } @Override - public RaftMessages.BaseRaftMessage apply( Queue terms, Queue contents ) + public Optional maybeComplete( Queue terms, Queue contents ) { if ( terms.size() < entryCount || contents.size() < entryCount ) { - return null; + return Optional.empty(); } - else + + RaftLogEntry[] entries = new RaftLogEntry[entryCount]; + for ( int i = 0; i < entryCount; i++ ) { - RaftLogEntry[] entries = new RaftLogEntry[entryCount]; - for ( int i = 0; i < entryCount; i++ ) + Long term = terms.poll(); + if ( term == null ) { - Long term = terms.poll(); - if ( term == null ) - { - throw new IllegalArgumentException( "Term cannot be null" ); - } - entries[i] = new RaftLogEntry( term, contents.poll() ); + throw new IllegalArgumentException( "Term cannot be null" ); } - return new RaftMessages.AppendEntries.Request( from, term, prevLogIndex, prevLogTerm, entries, leaderCommit ); + entries[i] = new RaftLogEntry( term, contents.poll() ); } + return Optional.of( new RaftMessages.AppendEntries.Request( from, term, prevLogIndex, prevLogTerm, entries, leaderCommit ) ); } } @@ -262,15 +265,15 @@ private static class NewEntryRequestComposer implements LazyComposer } @Override - public RaftMessages.BaseRaftMessage apply( Queue terms, Queue contents ) + public Optional maybeComplete( Queue terms, Queue contents ) { if ( contents.isEmpty() ) { - return null; + return Optional.empty(); } else { - return new RaftMessages.NewEntry.Request( from, contents.poll() ); + return Optional.of( new RaftMessages.NewEntry.Request( from, contents.poll() ) ); } } } diff --git a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/ClusterIdAwareMessageComposerTest.java b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/ClusterIdAwareMessageComposerTest.java index ff3dd850ab4de..4c80f529b66c6 100644 --- a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/ClusterIdAwareMessageComposerTest.java +++ b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/decoding/ClusterIdAwareMessageComposerTest.java @@ -27,6 +27,7 @@ import java.time.Clock; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.UUID; import org.neo4j.causalclustering.core.consensus.RaftMessages; @@ -40,7 +41,6 @@ public class ClusterIdAwareMessageComposerTest { - @Test public void shouldThrowExceptionOnConflictingMessageHeaders() { @@ -48,8 +48,8 @@ public void shouldThrowExceptionOnConflictingMessageHeaders() { RaftMessageComposer raftMessageComposer = new RaftMessageComposer( Clock.systemUTC() ); - raftMessageComposer.decode( null, messageCreator( ( a, b ) -> null ), null ); - raftMessageComposer.decode( null, messageCreator( ( a, b ) -> null ), null ); + raftMessageComposer.decode( null, messageCreator( ( a, b ) -> Optional.empty() ), null ); + raftMessageComposer.decode( null, messageCreator( ( a, b ) -> Optional.empty() ), null ); } catch ( IllegalStateException e ) { @@ -60,7 +60,7 @@ public void shouldThrowExceptionOnConflictingMessageHeaders() } @Test - public void shouldThrowExceptionIfNotAllResoucesAreUsed() + public void shouldThrowExceptionIfNotAllResourcesAreUsed() { try { @@ -69,7 +69,7 @@ public void shouldThrowExceptionIfNotAllResoucesAreUsed() ReplicatedTransaction replicatedTransaction = new ReplicatedTransaction( new byte[0] ); raftMessageComposer.decode( null, replicatedTransaction, null ); List out = new ArrayList<>(); - raftMessageComposer.decode( null, messageCreator( ( a, b ) -> dummyRequest() ), out ); + raftMessageComposer.decode( null, messageCreator( ( a, b ) -> Optional.of( dummyRequest() ) ), out ); } catch ( IllegalStateException e ) { @@ -102,8 +102,8 @@ private RaftMessages.PruneRequest dummyRequest() return new RaftMessages.PruneRequest( 1 ); } - private RaftMessageDecoder.ClusterIdAwareMessageComposer messageCreator( RaftMessageDecoder.LazyComposer biFunction ) + private RaftMessageDecoder.ClusterIdAwareMessageComposer messageCreator( RaftMessageDecoder.LazyComposer composer ) { - return new RaftMessageDecoder.ClusterIdAwareMessageComposer( biFunction, new ClusterId( UUID.randomUUID() ) ); + return new RaftMessageDecoder.ClusterIdAwareMessageComposer( composer, new ClusterId( UUID.randomUUID() ) ); } }