From 079bd23eb9db8c3383e2a9acf3d1d478c96cdfb2 Mon Sep 17 00:00:00 2001 From: RagnarW Date: Fri, 20 Jul 2018 15:29:38 +0200 Subject: [PATCH] Fix raft v1 TransactionRepresentationReplicatedTransaction does not have a defined size. This is required in raft v1. --- .../marshalling/v1/RaftMessageEncoder.java | 13 +- .../v2/RaftMessageEncoderDecoderTest.java | 148 +++++++++++++++--- 2 files changed, 136 insertions(+), 25 deletions(-) diff --git a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v1/RaftMessageEncoder.java b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v1/RaftMessageEncoder.java index 4a7c4ce565ea..94f5377ea51f 100644 --- a/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v1/RaftMessageEncoder.java +++ b/enterprise/causal-clustering/src/main/java/org/neo4j/causalclustering/messaging/marshalling/v1/RaftMessageEncoder.java @@ -29,6 +29,7 @@ import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry; import org.neo4j.causalclustering.core.replication.ReplicatedContent; +import org.neo4j.causalclustering.core.state.machines.tx.TransactionRepresentationReplicatedTransaction; import org.neo4j.causalclustering.identity.ClusterId; import org.neo4j.causalclustering.identity.MemberId; import org.neo4j.causalclustering.messaging.NetworkFlushableByteBuf; @@ -146,7 +147,17 @@ public Void handle( RaftMessages.AppendEntries.Response appendResponse ) @Override public Void handle( RaftMessages.NewEntry.Request newEntryRequest ) throws Exception { - marshal.marshal( newEntryRequest.content(), channel ); + ReplicatedContent content = newEntryRequest.content(); + ByteBuf buffer = channel.buffer(); + int contentStartIndex = buffer.writerIndex() + 1; + marshal.marshal( content, channel ); + if ( content instanceof TransactionRepresentationReplicatedTransaction ) + { + // TransactionRepresentationReplicatedTransaction does not support marshal because it has unknown size + int contentEndIndex = buffer.writerIndex(); + int size = contentEndIndex - contentStartIndex - Integer.BYTES; // the integer is the length integer which should be excluded + buffer.setInt( contentStartIndex, size ); + } return null; } diff --git a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/RaftMessageEncoderDecoderTest.java b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/RaftMessageEncoderDecoderTest.java index 41886fc0d750..f1d5bfaa0e3b 100644 --- a/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/RaftMessageEncoderDecoderTest.java +++ b/enterprise/causal-clustering/src/test/java/org/neo4j/causalclustering/messaging/marshalling/v2/RaftMessageEncoderDecoderTest.java @@ -22,6 +22,9 @@ */ package org.neo4j.causalclustering.messaging.marshalling.v2; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; @@ -32,15 +35,22 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.io.IOException; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Stream; import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry; +import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolClientInstallerV1; +import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolServerInstallerV1; import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolClientInstallerV2; import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolServerInstallerV2; import org.neo4j.causalclustering.core.replication.DistributedOperation; +import org.neo4j.causalclustering.core.replication.ReplicatedContent; import org.neo4j.causalclustering.core.replication.session.GlobalSession; import org.neo4j.causalclustering.core.replication.session.LocalOperationId; import org.neo4j.causalclustering.core.state.machines.locks.ReplicatedLockTokenRequest; @@ -50,40 +60,57 @@ import org.neo4j.causalclustering.handlers.VoidPipelineWrapperFactory; import org.neo4j.causalclustering.identity.ClusterId; import org.neo4j.causalclustering.identity.MemberId; +import org.neo4j.causalclustering.messaging.marshalling.ChunkedEncoder; import org.neo4j.causalclustering.protocol.NettyPipelineBuilderFactory; +import org.neo4j.kernel.impl.transaction.log.PhysicalTransactionRepresentation; import org.neo4j.logging.FormattedLogProvider; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +/** + * Warning! This test ensures that all raft protocol work as expected in theirs current implementation. However, it does not know about changes to the + * protocols that breaks backward compatibility. + */ @RunWith( Parameterized.class ) public class RaftMessageEncoderDecoderTest { private static final MemberId MEMBER_ID = new MemberId( UUID.randomUUID() ); + private static final int[] PROTOCOLS = {1, 2}; @Parameterized.Parameter() public RaftMessages.RaftMessage raftMessage; + @Parameterized.Parameter( 1 ) + public int raftProtocol; private final RaftMessageHandler handler = new RaftMessageHandler(); - @Parameterized.Parameters( name = "{0}" ) - public static RaftMessages.RaftMessage[] data() + @Parameterized.Parameters( name = "Raft v{1} with message {0}" ) + public static Object[] data() { - return new RaftMessages.RaftMessage[]{ - new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ), - new RaftMessages.HeartbeatResponse( MEMBER_ID ), + return setUpParams( new RaftMessages.RaftMessage[]{new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ), new RaftMessages.HeartbeatResponse( MEMBER_ID ), new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ) ), - new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation( - new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ), - new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 1, 2 ) ), - new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ), - new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3, new RaftLogEntry[]{ - new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ), - new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ), - new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ), - new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), - new RaftMessages.Vote.Response( MEMBER_ID, 1, true ), - new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), - new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ), - new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )}; + new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new PhysicalTransactionRepresentation( Collections.emptyList() ) ) ), + new RaftMessages.NewEntry.Request( MEMBER_ID, + ReplicatedTransaction.from( Unpooled.wrappedBuffer( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ).retain( 4 ) ) ), + new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation( + new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ), + new LocalOperationId( 1, 2 ) ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ), + new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3, + new RaftLogEntry[]{new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ), + new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ), + new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ), + new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), new RaftMessages.Vote.Response( MEMBER_ID, 1, true ), + new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), + new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ), new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )} ); + } + + private static Object[] setUpParams( RaftMessages.RaftMessage[] messages ) + { + return Arrays.stream( messages ).flatMap( (Function>) RaftMessageEncoderDecoderTest::params ).toArray(); + } + + private static Stream params( RaftMessages.RaftMessage raftMessage ) + { + return Arrays.stream( PROTOCOLS ).mapToObj( p -> new Object[]{raftMessage, p} ); } private EmbeddedChannel outbound; @@ -95,10 +122,24 @@ public void setupChannels() throws Exception outbound = new EmbeddedChannel(); inbound = new EmbeddedChannel(); - new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), - FormattedLogProvider.toOutputStream( System.out ) ).install( outbound ); - new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), - FormattedLogProvider.toOutputStream( System.out ) ).install( inbound ); + if ( raftProtocol == 2 ) + { + new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), + FormattedLogProvider.toOutputStream( System.out ) ).install( outbound ); + new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), + FormattedLogProvider.toOutputStream( System.out ) ).install( inbound ); + } + else if ( raftProtocol == 1 ) + { + new RaftProtocolClientInstallerV1( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), + FormattedLogProvider.toOutputStream( System.out ) ).install( outbound ); + new RaftProtocolServerInstallerV1( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), + FormattedLogProvider.toOutputStream( System.out ) ).install( inbound ); + } + else + { + throw new IllegalArgumentException( "Unknown raft protocol " + raftProtocol ); + } } @After @@ -116,7 +157,7 @@ public void cleanUp() } @Test - public void shouldEncodeDecodeRaftMessage() + public void shouldEncodeDecodeRaftMessage() throws IOException { ClusterId clusterId = new ClusterId( UUID.randomUUID() ); RaftMessages.ReceivedInstantClusterIdAwareMessage idAwareMessage = @@ -131,11 +172,70 @@ public void shouldEncodeDecodeRaftMessage() } RaftMessages.ReceivedInstantClusterIdAwareMessage message = handler.getRaftMessage(); assertEquals( clusterId, message.clusterId() ); - assertEquals( raftMessage, message.message() ); + raftMessageEquals( raftMessage, message.message() ); assertNull( inbound.readInbound() ); ReferenceCountUtil.release( handler.msg ); } + private void raftMessageEquals( RaftMessages.RaftMessage raftMessage, RaftMessages.RaftMessage message ) throws IOException + { + if ( raftMessage instanceof RaftMessages.NewEntry.Request ) + { + assertEquals( message.from(), raftMessage.from() ); + assertEquals( message.type(), raftMessage.type() ); + contentEquals( ((RaftMessages.NewEntry.Request) raftMessage).content(), ((RaftMessages.NewEntry.Request) raftMessage).content() ); + } + else if ( raftMessage instanceof RaftMessages.AppendEntries.Request ) + { + assertEquals( message.from(), raftMessage.from() ); + assertEquals( message.type(), raftMessage.type() ); + RaftLogEntry[] entries1 = ((RaftMessages.AppendEntries.Request) raftMessage).entries(); + RaftLogEntry[] entries2 = ((RaftMessages.AppendEntries.Request) message).entries(); + for ( int i = 0; i < entries1.length; i++ ) + { + RaftLogEntry raftLogEntry1 = entries1[i]; + RaftLogEntry raftLogEntry2 = entries2[i]; + assertEquals( raftLogEntry1.term(), raftLogEntry2.term() ); + contentEquals( raftLogEntry1.content(), raftLogEntry2.content() ); + } + } + } + + private void contentEquals( ReplicatedContent one, ReplicatedContent two ) throws IOException + { + if ( one instanceof ReplicatedTransaction ) + { + ByteBuf buffer1 = Unpooled.buffer(); + ByteBuf buffer2 = Unpooled.buffer(); + encode( buffer1, ((ReplicatedTransaction) one).marshal() ); + encode( buffer2, ((ReplicatedTransaction) two).marshal() ); + assertEquals( buffer1, buffer2 ); + } + else if ( one instanceof DistributedOperation ) + { + assertEquals( ((DistributedOperation) one).globalSession(), ((DistributedOperation) two).globalSession() ); + assertEquals( ((DistributedOperation) one).operationId(), ((DistributedOperation) two).operationId() ); + contentEquals( ((DistributedOperation) one).content(), ((DistributedOperation) two).content() ); + } + else + { + assertEquals( one, two ); + } + } + + private static void encode( ByteBuf buffer, ChunkedEncoder marshal ) throws IOException + { + while ( !marshal.isEndOfInput() ) + { + ByteBuf tmp = marshal.encodeChunk( UnpooledByteBufAllocator.DEFAULT ); + if ( tmp != null ) + { + buffer.writeBytes( tmp ); + tmp.release(); + } + } + } + class RaftMessageHandler extends SimpleChannelInboundHandler> {