Skip to content

Commit

Permalink
Fix raft v1
Browse files Browse the repository at this point in the history
TransactionRepresentationReplicatedTransaction does not have a defined
size. This is required in raft v1.
  • Loading branch information
RagnarW authored and martinfurmanski committed Sep 10, 2018
1 parent 24428d8 commit 079bd23
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 25 deletions.
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.RaftMessages;
import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry; import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry;
import org.neo4j.causalclustering.core.replication.ReplicatedContent; import org.neo4j.causalclustering.core.replication.ReplicatedContent;
import org.neo4j.causalclustering.core.state.machines.tx.TransactionRepresentationReplicatedTransaction;
import org.neo4j.causalclustering.identity.ClusterId; import org.neo4j.causalclustering.identity.ClusterId;
import org.neo4j.causalclustering.identity.MemberId; import org.neo4j.causalclustering.identity.MemberId;
import org.neo4j.causalclustering.messaging.NetworkFlushableByteBuf; import org.neo4j.causalclustering.messaging.NetworkFlushableByteBuf;
Expand Down Expand Up @@ -146,7 +147,17 @@ public Void handle( RaftMessages.AppendEntries.Response appendResponse )
@Override @Override
public Void handle( RaftMessages.NewEntry.Request newEntryRequest ) throws Exception public Void handle( RaftMessages.NewEntry.Request newEntryRequest ) throws Exception
{ {
marshal.marshal( newEntryRequest.content(), channel ); ReplicatedContent content = newEntryRequest.content();
ByteBuf buffer = channel.buffer();
int contentStartIndex = buffer.writerIndex() + 1;
marshal.marshal( content, channel );
if ( content instanceof TransactionRepresentationReplicatedTransaction )
{
// TransactionRepresentationReplicatedTransaction does not support marshal because it has unknown size
int contentEndIndex = buffer.writerIndex();
int size = contentEndIndex - contentStartIndex - Integer.BYTES; // the integer is the length integer which should be excluded
buffer.setInt( contentStartIndex, size );
}


return null; return null;
} }
Expand Down
Expand Up @@ -22,6 +22,9 @@
*/ */
package org.neo4j.causalclustering.messaging.marshalling.v2; package org.neo4j.causalclustering.messaging.marshalling.v2;


import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
Expand All @@ -32,15 +35,22 @@
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;


import java.io.IOException;
import java.time.Instant; import java.time.Instant;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.UUID; import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Stream;


import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.RaftMessages;
import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry; import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry;
import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolClientInstallerV1;
import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolServerInstallerV1;
import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolClientInstallerV2; import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolClientInstallerV2;
import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolServerInstallerV2; import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolServerInstallerV2;
import org.neo4j.causalclustering.core.replication.DistributedOperation; import org.neo4j.causalclustering.core.replication.DistributedOperation;
import org.neo4j.causalclustering.core.replication.ReplicatedContent;
import org.neo4j.causalclustering.core.replication.session.GlobalSession; import org.neo4j.causalclustering.core.replication.session.GlobalSession;
import org.neo4j.causalclustering.core.replication.session.LocalOperationId; import org.neo4j.causalclustering.core.replication.session.LocalOperationId;
import org.neo4j.causalclustering.core.state.machines.locks.ReplicatedLockTokenRequest; import org.neo4j.causalclustering.core.state.machines.locks.ReplicatedLockTokenRequest;
Expand All @@ -50,40 +60,57 @@
import org.neo4j.causalclustering.handlers.VoidPipelineWrapperFactory; import org.neo4j.causalclustering.handlers.VoidPipelineWrapperFactory;
import org.neo4j.causalclustering.identity.ClusterId; import org.neo4j.causalclustering.identity.ClusterId;
import org.neo4j.causalclustering.identity.MemberId; import org.neo4j.causalclustering.identity.MemberId;
import org.neo4j.causalclustering.messaging.marshalling.ChunkedEncoder;
import org.neo4j.causalclustering.protocol.NettyPipelineBuilderFactory; import org.neo4j.causalclustering.protocol.NettyPipelineBuilderFactory;
import org.neo4j.kernel.impl.transaction.log.PhysicalTransactionRepresentation;
import org.neo4j.logging.FormattedLogProvider; import org.neo4j.logging.FormattedLogProvider;


import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;


/**
* Warning! This test ensures that all raft protocol work as expected in theirs current implementation. However, it does not know about changes to the
* protocols that breaks backward compatibility.
*/
@RunWith( Parameterized.class ) @RunWith( Parameterized.class )
public class RaftMessageEncoderDecoderTest public class RaftMessageEncoderDecoderTest
{ {
private static final MemberId MEMBER_ID = new MemberId( UUID.randomUUID() ); private static final MemberId MEMBER_ID = new MemberId( UUID.randomUUID() );
private static final int[] PROTOCOLS = {1, 2};
@Parameterized.Parameter() @Parameterized.Parameter()
public RaftMessages.RaftMessage raftMessage; public RaftMessages.RaftMessage raftMessage;
@Parameterized.Parameter( 1 )
public int raftProtocol;
private final RaftMessageHandler handler = new RaftMessageHandler(); private final RaftMessageHandler handler = new RaftMessageHandler();


@Parameterized.Parameters( name = "{0}" ) @Parameterized.Parameters( name = "Raft v{1} with message {0}" )
public static RaftMessages.RaftMessage[] data() public static Object[] data()
{ {
return new RaftMessages.RaftMessage[]{ return setUpParams( new RaftMessages.RaftMessage[]{new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ), new RaftMessages.HeartbeatResponse( MEMBER_ID ),
new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ),
new RaftMessages.HeartbeatResponse( MEMBER_ID ),
new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ) ), new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ) ),
new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation( new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new PhysicalTransactionRepresentation( Collections.emptyList() ) ) ),
new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ), new RaftMessages.NewEntry.Request( MEMBER_ID,
new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 1, 2 ) ), ReplicatedTransaction.from( Unpooled.wrappedBuffer( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ).retain( 4 ) ) ),
new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ), new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation(
new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3, new RaftLogEntry[]{ new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ),
new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ), new LocalOperationId( 1, 2 ) ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ),
new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ), new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3,
new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ), new RaftLogEntry[]{new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ),
new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ),
new RaftMessages.Vote.Response( MEMBER_ID, 1, true ), new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ),
new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), new RaftMessages.Vote.Response( MEMBER_ID, 1, true ),
new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ), new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ),
new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )}; new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ), new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )} );
}

private static Object[] setUpParams( RaftMessages.RaftMessage[] messages )
{
return Arrays.stream( messages ).flatMap( (Function<RaftMessages.RaftMessage,Stream<?>>) RaftMessageEncoderDecoderTest::params ).toArray();
}

private static Stream<Object[]> params( RaftMessages.RaftMessage raftMessage )
{
return Arrays.stream( PROTOCOLS ).mapToObj( p -> new Object[]{raftMessage, p} );
} }


private EmbeddedChannel outbound; private EmbeddedChannel outbound;
Expand All @@ -95,10 +122,24 @@ public void setupChannels() throws Exception
outbound = new EmbeddedChannel(); outbound = new EmbeddedChannel();
inbound = new EmbeddedChannel(); inbound = new EmbeddedChannel();


new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), if ( raftProtocol == 2 )
FormattedLogProvider.toOutputStream( System.out ) ).install( outbound ); {
new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(), new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound ); FormattedLogProvider.toOutputStream( System.out ) ).install( outbound );
new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound );
}
else if ( raftProtocol == 1 )
{
new RaftProtocolClientInstallerV1( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( outbound );
new RaftProtocolServerInstallerV1( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound );
}
else
{
throw new IllegalArgumentException( "Unknown raft protocol " + raftProtocol );
}
} }


@After @After
Expand All @@ -116,7 +157,7 @@ public void cleanUp()
} }


@Test @Test
public void shouldEncodeDecodeRaftMessage() public void shouldEncodeDecodeRaftMessage() throws IOException
{ {
ClusterId clusterId = new ClusterId( UUID.randomUUID() ); ClusterId clusterId = new ClusterId( UUID.randomUUID() );
RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> idAwareMessage = RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> idAwareMessage =
Expand All @@ -131,11 +172,70 @@ public void shouldEncodeDecodeRaftMessage()
} }
RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> message = handler.getRaftMessage(); RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> message = handler.getRaftMessage();
assertEquals( clusterId, message.clusterId() ); assertEquals( clusterId, message.clusterId() );
assertEquals( raftMessage, message.message() ); raftMessageEquals( raftMessage, message.message() );
assertNull( inbound.readInbound() ); assertNull( inbound.readInbound() );
ReferenceCountUtil.release( handler.msg ); ReferenceCountUtil.release( handler.msg );
} }


private void raftMessageEquals( RaftMessages.RaftMessage raftMessage, RaftMessages.RaftMessage message ) throws IOException
{
if ( raftMessage instanceof RaftMessages.NewEntry.Request )
{
assertEquals( message.from(), raftMessage.from() );
assertEquals( message.type(), raftMessage.type() );
contentEquals( ((RaftMessages.NewEntry.Request) raftMessage).content(), ((RaftMessages.NewEntry.Request) raftMessage).content() );
}
else if ( raftMessage instanceof RaftMessages.AppendEntries.Request )
{
assertEquals( message.from(), raftMessage.from() );
assertEquals( message.type(), raftMessage.type() );
RaftLogEntry[] entries1 = ((RaftMessages.AppendEntries.Request) raftMessage).entries();
RaftLogEntry[] entries2 = ((RaftMessages.AppendEntries.Request) message).entries();
for ( int i = 0; i < entries1.length; i++ )
{
RaftLogEntry raftLogEntry1 = entries1[i];
RaftLogEntry raftLogEntry2 = entries2[i];
assertEquals( raftLogEntry1.term(), raftLogEntry2.term() );
contentEquals( raftLogEntry1.content(), raftLogEntry2.content() );
}
}
}

private void contentEquals( ReplicatedContent one, ReplicatedContent two ) throws IOException
{
if ( one instanceof ReplicatedTransaction )
{
ByteBuf buffer1 = Unpooled.buffer();
ByteBuf buffer2 = Unpooled.buffer();
encode( buffer1, ((ReplicatedTransaction) one).marshal() );
encode( buffer2, ((ReplicatedTransaction) two).marshal() );
assertEquals( buffer1, buffer2 );
}
else if ( one instanceof DistributedOperation )
{
assertEquals( ((DistributedOperation) one).globalSession(), ((DistributedOperation) two).globalSession() );
assertEquals( ((DistributedOperation) one).operationId(), ((DistributedOperation) two).operationId() );
contentEquals( ((DistributedOperation) one).content(), ((DistributedOperation) two).content() );
}
else
{
assertEquals( one, two );
}
}

private static void encode( ByteBuf buffer, ChunkedEncoder marshal ) throws IOException
{
while ( !marshal.isEndOfInput() )
{
ByteBuf tmp = marshal.encodeChunk( UnpooledByteBufAllocator.DEFAULT );
if ( tmp != null )
{
buffer.writeBytes( tmp );
tmp.release();
}
}
}

class RaftMessageHandler extends SimpleChannelInboundHandler<RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage>> class RaftMessageHandler extends SimpleChannelInboundHandler<RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage>>
{ {


Expand Down

0 comments on commit 079bd23

Please sign in to comment.