diff --git a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/RaftState.java b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/RaftState.java index 2f44309d8ceda..5934fa4baf76f 100644 --- a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/RaftState.java +++ b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/RaftState.java @@ -128,12 +128,16 @@ public ReadableRaftLog entryLog() public void update( Outcome outcome ) throws RaftStorageException { - termState.update( outcome.getTerm() ); - voteState.votedFor( outcome.getVotedFor(), outcome.getTerm() ); try { - termStorage.persistStoreData( termState ); - voteStorage.persistStoreData( voteState ); + if ( termState.update( outcome.getTerm() ) ) + { + termStorage.persistStoreData( termState ); + } + if ( voteState.update( outcome.getVotedFor(), outcome.getTerm() ) ) + { + voteStorage.persistStoreData( voteState ); + } } catch ( IOException e ) { diff --git a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/term/TermState.java b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/term/TermState.java index b38dc74adb4d4..f66429d84f8db 100644 --- a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/term/TermState.java +++ b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/term/TermState.java @@ -52,10 +52,12 @@ public long currentTerm() * * @param newTerm The new value. */ - public void update( long newTerm ) + public boolean update( long newTerm ) { failIfInvalid( newTerm ); + boolean changed = term != newTerm; term = newTerm; + return changed; } /** diff --git a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/vote/VoteState.java b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/vote/VoteState.java index 3a99f5cb881ec..3a63a5704a58d 100644 --- a/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/vote/VoteState.java +++ b/enterprise/core-edge/src/main/java/org/neo4j/coreedge/raft/state/vote/VoteState.java @@ -47,34 +47,37 @@ public MEMBER votedFor() return votedFor; } - public void votedFor( MEMBER votedFor, long term ) + public boolean update( MEMBER votedFor, long term ) { - assert ensureVoteIsUniquePerTerm( votedFor, term ) : "Votes for any instance should always be in more recent terms"; - - this.votedFor = votedFor; - this.term = term; - } - - private boolean ensureVoteIsUniquePerTerm( MEMBER votedFor, long term ) - { - if ( votedFor == null && this.votedFor == null ) - { - return true; - } - else if ( votedFor == null ) - { - return term > this.term; - } - else if ( this.votedFor == null ) + if ( termChanged( term ) ) { + this.votedFor = votedFor; + this.term = term; return true; } else { - return this.votedFor.equals( votedFor ) || term > this.term; + if ( this.votedFor == null ) + { + if ( votedFor != null ) + { + this.votedFor = votedFor; + return true; + } + } + else if ( !this.votedFor.equals( votedFor ) ) + { + throw new IllegalArgumentException( "Can only vote once per term." ); + } + return false; } } + private boolean termChanged( long term ) + { + return term != this.term; + } + public long term() { return term; diff --git a/enterprise/core-edge/src/test/java/org/neo4j/coreedge/raft/state/VoteStateTest.java b/enterprise/core-edge/src/test/java/org/neo4j/coreedge/raft/state/VoteStateTest.java index 434858971cd5f..54f1b5024bd72 100644 --- a/enterprise/core-edge/src/test/java/org/neo4j/coreedge/raft/state/VoteStateTest.java +++ b/enterprise/core-edge/src/test/java/org/neo4j/coreedge/raft/state/VoteStateTest.java @@ -26,7 +26,10 @@ import org.neo4j.coreedge.server.CoreMember; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class VoteStateTest { @@ -39,7 +42,7 @@ public void shouldStoreVote() throws Exception new AdvertisedSocketAddress( "host1:2001" ) ); // when - voteState.votedFor( member, 0 ); + voteState.update( member, 0 ); // then assertEquals( member, voteState.votedFor() ); @@ -66,8 +69,8 @@ public void shouldUpdateVote() throws Exception new AdvertisedSocketAddress( "host2:2001" ) ); // when - voteState.votedFor( member1, 0 ); - voteState.votedFor( member2, 1 ); + voteState.update( member1, 0 ); + voteState.update( member2, 1 ); // then assertEquals( member2, voteState.votedFor() ); @@ -80,12 +83,77 @@ public void shouldClearVote() throws Exception VoteState voteState = new VoteState<>(); CoreMember member = new CoreMember( new AdvertisedSocketAddress( "host1:1001" ), new AdvertisedSocketAddress( "host1:2001" ) ); - voteState.votedFor( member, 0 ); + voteState.update( member, 0 ); // when - voteState.votedFor( null, 1 ); + voteState.update( null, 1 ); // then assertNull( voteState.votedFor() ); } + + @Test + public void shouldNotUpdateVoteForSameTerm() throws Exception + { + // given + VoteState voteState = new VoteState<>(); + CoreMember member1 = new CoreMember( new AdvertisedSocketAddress( "host1:1001" ), + new AdvertisedSocketAddress( "host1:2001" ) ); + CoreMember member2 = new CoreMember( new AdvertisedSocketAddress( "host2:1001" ), + new AdvertisedSocketAddress( "host2:2001" ) ); + + voteState.update( member1, 0 ); + + try + { + // when + voteState.update( member2, 0 ); + fail( "Should have thrown IllegalArgumentException" ); + } + catch ( IllegalArgumentException expected ) + { + // expected + } + } + + @Test + public void shouldNotClearVoteForSameTerm() throws Exception + { + // given + VoteState voteState = new VoteState<>(); + CoreMember member1 = new CoreMember( new AdvertisedSocketAddress( "host1:1001" ), + new AdvertisedSocketAddress( "host1:2001" ) ); + + voteState.update( member1, 0 ); + + try + { + // when + voteState.update( null, 0 ); + fail( "Should have thrown IllegalArgumentException" ); + } + catch ( IllegalArgumentException expected ) + { + // expected + } + } + + @Test + public void shouldReportNoUpdateWhenVoteStateUnchanged() throws Exception + { + // given + VoteState voteState = new VoteState<>(); + CoreMember member1 = new CoreMember( new AdvertisedSocketAddress( "host1:1001" ), + new AdvertisedSocketAddress( "host1:2001" ) ); + CoreMember member2 = new CoreMember( new AdvertisedSocketAddress( "host2:1001" ), + new AdvertisedSocketAddress( "host2:2001" ) ); + + // when + assertTrue( voteState.update( null, 0 ) ); + assertFalse( voteState.update( null, 0 ) ); + assertTrue( voteState.update( member1, 0 ) ); + assertFalse( voteState.update( member1, 0 ) ); + assertTrue( voteState.update( member2, 1 ) ); + assertFalse( voteState.update( member2, 1 ) ); + } }