Skip to content

Commit

Permalink
Fix raft v1
Browse files Browse the repository at this point in the history
TransactionRepresentationReplicatedTransaction does not have a defined
size. This is required in raft v1.
  • Loading branch information
RagnarW authored and martinfurmanski committed Sep 10, 2018
1 parent 24428d8 commit 079bd23
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 25 deletions.
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.causalclustering.core.consensus.RaftMessages;
import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry;
import org.neo4j.causalclustering.core.replication.ReplicatedContent;
import org.neo4j.causalclustering.core.state.machines.tx.TransactionRepresentationReplicatedTransaction;
import org.neo4j.causalclustering.identity.ClusterId;
import org.neo4j.causalclustering.identity.MemberId;
import org.neo4j.causalclustering.messaging.NetworkFlushableByteBuf;
Expand Down Expand Up @@ -146,7 +147,17 @@ public Void handle( RaftMessages.AppendEntries.Response appendResponse )
@Override
public Void handle( RaftMessages.NewEntry.Request newEntryRequest ) throws Exception
{
marshal.marshal( newEntryRequest.content(), channel );
ReplicatedContent content = newEntryRequest.content();
ByteBuf buffer = channel.buffer();
int contentStartIndex = buffer.writerIndex() + 1;
marshal.marshal( content, channel );
if ( content instanceof TransactionRepresentationReplicatedTransaction )
{
// TransactionRepresentationReplicatedTransaction does not support marshal because it has unknown size
int contentEndIndex = buffer.writerIndex();
int size = contentEndIndex - contentStartIndex - Integer.BYTES; // the integer is the length integer which should be excluded
buffer.setInt( contentStartIndex, size );
}

return null;
}
Expand Down
Expand Up @@ -22,6 +22,9 @@
*/
package org.neo4j.causalclustering.messaging.marshalling.v2;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
Expand All @@ -32,15 +35,22 @@
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Stream;

import org.neo4j.causalclustering.core.consensus.RaftMessages;
import org.neo4j.causalclustering.core.consensus.log.RaftLogEntry;
import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolClientInstallerV1;
import org.neo4j.causalclustering.core.consensus.protocol.v1.RaftProtocolServerInstallerV1;
import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolClientInstallerV2;
import org.neo4j.causalclustering.core.consensus.protocol.v2.RaftProtocolServerInstallerV2;
import org.neo4j.causalclustering.core.replication.DistributedOperation;
import org.neo4j.causalclustering.core.replication.ReplicatedContent;
import org.neo4j.causalclustering.core.replication.session.GlobalSession;
import org.neo4j.causalclustering.core.replication.session.LocalOperationId;
import org.neo4j.causalclustering.core.state.machines.locks.ReplicatedLockTokenRequest;
Expand All @@ -50,40 +60,57 @@
import org.neo4j.causalclustering.handlers.VoidPipelineWrapperFactory;
import org.neo4j.causalclustering.identity.ClusterId;
import org.neo4j.causalclustering.identity.MemberId;
import org.neo4j.causalclustering.messaging.marshalling.ChunkedEncoder;
import org.neo4j.causalclustering.protocol.NettyPipelineBuilderFactory;
import org.neo4j.kernel.impl.transaction.log.PhysicalTransactionRepresentation;
import org.neo4j.logging.FormattedLogProvider;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;

/**
* Warning! This test ensures that all raft protocol work as expected in theirs current implementation. However, it does not know about changes to the
* protocols that breaks backward compatibility.
*/
@RunWith( Parameterized.class )
public class RaftMessageEncoderDecoderTest
{
private static final MemberId MEMBER_ID = new MemberId( UUID.randomUUID() );
private static final int[] PROTOCOLS = {1, 2};
@Parameterized.Parameter()
public RaftMessages.RaftMessage raftMessage;
@Parameterized.Parameter( 1 )
public int raftProtocol;
private final RaftMessageHandler handler = new RaftMessageHandler();

@Parameterized.Parameters( name = "{0}" )
public static RaftMessages.RaftMessage[] data()
@Parameterized.Parameters( name = "Raft v{1} with message {0}" )
public static Object[] data()
{
return new RaftMessages.RaftMessage[]{
new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ),
new RaftMessages.HeartbeatResponse( MEMBER_ID ),
return setUpParams( new RaftMessages.RaftMessage[]{new RaftMessages.Heartbeat( MEMBER_ID, 1, 2, 3 ), new RaftMessages.HeartbeatResponse( MEMBER_ID ),
new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ) ),
new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation(
new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ),
new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 1, 2 ) ),
new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ),
new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3, new RaftLogEntry[]{
new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ),
new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ),
new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ),
new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ),
new RaftMessages.Vote.Response( MEMBER_ID, 1, true ),
new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ),
new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ),
new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )};
new RaftMessages.NewEntry.Request( MEMBER_ID, ReplicatedTransaction.from( new PhysicalTransactionRepresentation( Collections.emptyList() ) ) ),
new RaftMessages.NewEntry.Request( MEMBER_ID,
ReplicatedTransaction.from( Unpooled.wrappedBuffer( new byte[]{1, 2, 3, 4, 5, 6, 7, 8} ).retain( 4 ) ) ),
new RaftMessages.NewEntry.Request( MEMBER_ID, new DistributedOperation(
new DistributedOperation( ReplicatedTransaction.from( new byte[]{1, 2, 3, 4, 5} ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ),
new LocalOperationId( 1, 2 ) ), new GlobalSession( UUID.randomUUID(), MEMBER_ID ), new LocalOperationId( 3, 4 ) ) ),
new RaftMessages.AppendEntries.Request( MEMBER_ID, 1, 2, 3,
new RaftLogEntry[]{new RaftLogEntry( 0, new ReplicatedTokenRequest( TokenType.LABEL, "name", new byte[]{2, 3, 4} ) ),
new RaftLogEntry( 1, new ReplicatedLockTokenRequest( MEMBER_ID, 2 ) )}, 5 ),
new RaftMessages.AppendEntries.Response( MEMBER_ID, 1, true, 2, 3 ),
new RaftMessages.Vote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ), new RaftMessages.Vote.Response( MEMBER_ID, 1, true ),
new RaftMessages.PreVote.Request( MEMBER_ID, Long.MAX_VALUE, MEMBER_ID, Long.MIN_VALUE, 1 ),
new RaftMessages.PreVote.Response( MEMBER_ID, 1, true ), new RaftMessages.LogCompactionInfo( MEMBER_ID, Long.MAX_VALUE, Long.MIN_VALUE )} );
}

private static Object[] setUpParams( RaftMessages.RaftMessage[] messages )
{
return Arrays.stream( messages ).flatMap( (Function<RaftMessages.RaftMessage,Stream<?>>) RaftMessageEncoderDecoderTest::params ).toArray();
}

private static Stream<Object[]> params( RaftMessages.RaftMessage raftMessage )
{
return Arrays.stream( PROTOCOLS ).mapToObj( p -> new Object[]{raftMessage, p} );
}

private EmbeddedChannel outbound;
Expand All @@ -95,10 +122,24 @@ public void setupChannels() throws Exception
outbound = new EmbeddedChannel();
inbound = new EmbeddedChannel();

new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( outbound );
new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound );
if ( raftProtocol == 2 )
{
new RaftProtocolClientInstallerV2( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( outbound );
new RaftProtocolServerInstallerV2( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound );
}
else if ( raftProtocol == 1 )
{
new RaftProtocolClientInstallerV1( new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( outbound );
new RaftProtocolServerInstallerV1( handler, new NettyPipelineBuilderFactory( VoidPipelineWrapperFactory.VOID_WRAPPER ), Collections.emptyList(),
FormattedLogProvider.toOutputStream( System.out ) ).install( inbound );
}
else
{
throw new IllegalArgumentException( "Unknown raft protocol " + raftProtocol );
}
}

@After
Expand All @@ -116,7 +157,7 @@ public void cleanUp()
}

@Test
public void shouldEncodeDecodeRaftMessage()
public void shouldEncodeDecodeRaftMessage() throws IOException
{
ClusterId clusterId = new ClusterId( UUID.randomUUID() );
RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> idAwareMessage =
Expand All @@ -131,11 +172,70 @@ public void shouldEncodeDecodeRaftMessage()
}
RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage> message = handler.getRaftMessage();
assertEquals( clusterId, message.clusterId() );
assertEquals( raftMessage, message.message() );
raftMessageEquals( raftMessage, message.message() );
assertNull( inbound.readInbound() );
ReferenceCountUtil.release( handler.msg );
}

private void raftMessageEquals( RaftMessages.RaftMessage raftMessage, RaftMessages.RaftMessage message ) throws IOException
{
if ( raftMessage instanceof RaftMessages.NewEntry.Request )
{
assertEquals( message.from(), raftMessage.from() );
assertEquals( message.type(), raftMessage.type() );
contentEquals( ((RaftMessages.NewEntry.Request) raftMessage).content(), ((RaftMessages.NewEntry.Request) raftMessage).content() );
}
else if ( raftMessage instanceof RaftMessages.AppendEntries.Request )
{
assertEquals( message.from(), raftMessage.from() );
assertEquals( message.type(), raftMessage.type() );
RaftLogEntry[] entries1 = ((RaftMessages.AppendEntries.Request) raftMessage).entries();
RaftLogEntry[] entries2 = ((RaftMessages.AppendEntries.Request) message).entries();
for ( int i = 0; i < entries1.length; i++ )
{
RaftLogEntry raftLogEntry1 = entries1[i];
RaftLogEntry raftLogEntry2 = entries2[i];
assertEquals( raftLogEntry1.term(), raftLogEntry2.term() );
contentEquals( raftLogEntry1.content(), raftLogEntry2.content() );
}
}
}

private void contentEquals( ReplicatedContent one, ReplicatedContent two ) throws IOException
{
if ( one instanceof ReplicatedTransaction )
{
ByteBuf buffer1 = Unpooled.buffer();
ByteBuf buffer2 = Unpooled.buffer();
encode( buffer1, ((ReplicatedTransaction) one).marshal() );
encode( buffer2, ((ReplicatedTransaction) two).marshal() );
assertEquals( buffer1, buffer2 );
}
else if ( one instanceof DistributedOperation )
{
assertEquals( ((DistributedOperation) one).globalSession(), ((DistributedOperation) two).globalSession() );
assertEquals( ((DistributedOperation) one).operationId(), ((DistributedOperation) two).operationId() );
contentEquals( ((DistributedOperation) one).content(), ((DistributedOperation) two).content() );
}
else
{
assertEquals( one, two );
}
}

private static void encode( ByteBuf buffer, ChunkedEncoder marshal ) throws IOException
{
while ( !marshal.isEndOfInput() )
{
ByteBuf tmp = marshal.encodeChunk( UnpooledByteBufAllocator.DEFAULT );
if ( tmp != null )
{
buffer.writeBytes( tmp );
tmp.release();
}
}
}

class RaftMessageHandler extends SimpleChannelInboundHandler<RaftMessages.ReceivedInstantClusterIdAwareMessage<RaftMessages.RaftMessage>>
{

Expand Down

0 comments on commit 079bd23

Please sign in to comment.