Skip to content

Commit

Permalink
Assert local database is healthy in RaftReplicator
Browse files Browse the repository at this point in the history
This ensures that replication is aborted if the local database is not healthy.
  • Loading branch information
RagnarW committed Oct 11, 2018
1 parent e0fee18 commit da3a460
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 37 deletions.
Expand Up @@ -26,6 +26,7 @@
import java.time.Duration; import java.time.Duration;
import java.util.UUID; import java.util.UUID;


import org.neo4j.causalclustering.catchup.storecopy.LocalDatabase;
import org.neo4j.causalclustering.core.CausalClusteringSettings; import org.neo4j.causalclustering.core.CausalClusteringSettings;
import org.neo4j.causalclustering.core.consensus.RaftMachine; import org.neo4j.causalclustering.core.consensus.RaftMachine;
import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.RaftMessages;
Expand Down Expand Up @@ -55,8 +56,8 @@ public class ReplicationModule
private final SessionTracker sessionTracker; private final SessionTracker sessionTracker;


public ReplicationModule( RaftMachine raftMachine, MemberId myself, PlatformModule platformModule, Config config, public ReplicationModule( RaftMachine raftMachine, MemberId myself, PlatformModule platformModule, Config config,
Outbound<MemberId,RaftMessages.RaftMessage> outbound, Outbound<MemberId,RaftMessages.RaftMessage> outbound, File clusterStateDirectory, FileSystemAbstraction fileSystem, LogProvider logProvider,
File clusterStateDirectory, FileSystemAbstraction fileSystem, LogProvider logProvider, AvailabilityGuard globalAvailabilityGuard ) AvailabilityGuard globalAvailabilityGuard, LocalDatabase localDatabase )
{ {
LifeSupport life = platformModule.life; LifeSupport life = platformModule.life;


Expand All @@ -82,7 +83,7 @@ public ReplicationModule( RaftMachine raftMachine, MemberId myself, PlatformModu
outbound, outbound,
sessionPool, sessionPool,
progressTracker, progressRetryStrategy, availabilityTimeoutMillis, progressTracker, progressRetryStrategy, availabilityTimeoutMillis,
globalAvailabilityGuard, logProvider, globalAvailabilityGuard, logProvider, localDatabase,
platformModule.monitors ); platformModule.monitors );
} }


Expand Down
Expand Up @@ -299,7 +299,7 @@ public EnterpriseCoreEditionModule( final PlatformModule platformModule,
dependencies.satisfyDependency( consensusModule.raftMachine() ); dependencies.satisfyDependency( consensusModule.raftMachine() );


replicationModule = new ReplicationModule( consensusModule.raftMachine(), identityModule.myself(), platformModule, config, loggingOutbound, replicationModule = new ReplicationModule( consensusModule.raftMachine(), identityModule.myself(), platformModule, config, loggingOutbound,
clusterStateDirectory.get(), fileSystem, logProvider, globalGuard ); clusterStateDirectory.get(), fileSystem, logProvider, globalGuard, localDatabase );


coreStateMachinesModule = new CoreStateMachinesModule( identityModule.myself(), coreStateMachinesModule = new CoreStateMachinesModule( identityModule.myself(),
platformModule, clusterStateDirectory.get(), config, replicationModule.getReplicator(), platformModule, clusterStateDirectory.get(), config, replicationModule.getReplicator(),
Expand Down
Expand Up @@ -25,6 +25,7 @@
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;


import org.neo4j.causalclustering.catchup.storecopy.LocalDatabase;
import org.neo4j.causalclustering.core.consensus.LeaderInfo; import org.neo4j.causalclustering.core.consensus.LeaderInfo;
import org.neo4j.causalclustering.core.consensus.LeaderListener; import org.neo4j.causalclustering.core.consensus.LeaderListener;
import org.neo4j.causalclustering.core.consensus.LeaderLocator; import org.neo4j.causalclustering.core.consensus.LeaderLocator;
Expand Down Expand Up @@ -53,13 +54,14 @@ public class RaftReplicator implements Replicator, LeaderListener
private final TimeoutStrategy progressTimeoutStrategy; private final TimeoutStrategy progressTimeoutStrategy;
private final AvailabilityGuard availabilityGuard; private final AvailabilityGuard availabilityGuard;
private final Log log; private final Log log;
private final LocalDatabase localDatabase;
private final ReplicationMonitor replicationMonitor; private final ReplicationMonitor replicationMonitor;
private final long availabilityTimeoutMillis; private final long availabilityTimeoutMillis;
private final LeaderProvider leaderProvider; private final LeaderProvider leaderProvider;


public RaftReplicator( LeaderLocator leaderLocator, MemberId me, Outbound<MemberId,RaftMessages.RaftMessage> outbound, LocalSessionPool sessionPool, public RaftReplicator( LeaderLocator leaderLocator, MemberId me, Outbound<MemberId,RaftMessages.RaftMessage> outbound, LocalSessionPool sessionPool,
ProgressTracker progressTracker, TimeoutStrategy progressTimeoutStrategy, long availabilityTimeoutMillis, AvailabilityGuard availabilityGuard, ProgressTracker progressTracker, TimeoutStrategy progressTimeoutStrategy, long availabilityTimeoutMillis, AvailabilityGuard availabilityGuard,
LogProvider logProvider, Monitors monitors ) LogProvider logProvider, LocalDatabase localDatabase, Monitors monitors )
{ {
this.me = me; this.me = me;
this.outbound = outbound; this.outbound = outbound;
Expand All @@ -69,6 +71,7 @@ public RaftReplicator( LeaderLocator leaderLocator, MemberId me, Outbound<Member
this.availabilityTimeoutMillis = availabilityTimeoutMillis; this.availabilityTimeoutMillis = availabilityTimeoutMillis;
this.availabilityGuard = availabilityGuard; this.availabilityGuard = availabilityGuard;
this.log = logProvider.getLog( getClass() ); this.log = logProvider.getLog( getClass() );
this.localDatabase = localDatabase;
this.replicationMonitor = monitors.newMonitor( ReplicationMonitor.class ); this.replicationMonitor = monitors.newMonitor( ReplicationMonitor.class );
this.leaderProvider = new LeaderProvider(); this.leaderProvider = new LeaderProvider();
leaderLocator.registerListener( this ); leaderLocator.registerListener( this );
Expand Down Expand Up @@ -165,6 +168,7 @@ else if ( newLeader != null && oldLeader == null )


private void assertDatabaseAvailable() throws ReplicationFailureException private void assertDatabaseAvailable() throws ReplicationFailureException
{ {
localDatabase.assertHealthy( ReplicationFailureException.class );
try try
{ {
availabilityGuard.await( availabilityTimeoutMillis ); availabilityGuard.await( availabilityTimeoutMillis );
Expand Down
Expand Up @@ -24,7 +24,9 @@


public class ReplicationFailureException extends Exception public class ReplicationFailureException extends Exception
{ {
ReplicationFailureException( String message, Throwable cause ) // needs to be public due to reflection
@SuppressWarnings( "WeakerAccess" )
public ReplicationFailureException( String message, Throwable cause )
{ {
super( message, cause ); super( message, cause );
} }
Expand Down
Expand Up @@ -23,14 +23,18 @@
package org.neo4j.causalclustering.core.replication; package org.neo4j.causalclustering.core.replication;


import org.hamcrest.Matchers; import org.hamcrest.Matchers;
import org.junit.Rule; import org.junit.jupiter.api.Assertions;
import org.junit.Test; import org.junit.jupiter.api.BeforeEach;
import org.junit.rules.ExpectedException; import org.junit.jupiter.api.Test;


import java.io.IOException;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.function.Supplier;


import org.neo4j.causalclustering.catchup.storecopy.LocalDatabase;
import org.neo4j.causalclustering.catchup.storecopy.StoreFiles;
import org.neo4j.causalclustering.core.consensus.LeaderInfo; import org.neo4j.causalclustering.core.consensus.LeaderInfo;
import org.neo4j.causalclustering.core.consensus.LeaderLocator; import org.neo4j.causalclustering.core.consensus.LeaderLocator;
import org.neo4j.causalclustering.core.consensus.RaftMessages; import org.neo4j.causalclustering.core.consensus.RaftMessages;
Expand All @@ -42,10 +46,15 @@
import org.neo4j.causalclustering.helper.ConstantTimeTimeoutStrategy; import org.neo4j.causalclustering.helper.ConstantTimeTimeoutStrategy;
import org.neo4j.causalclustering.helper.TimeoutStrategy; import org.neo4j.causalclustering.helper.TimeoutStrategy;
import org.neo4j.causalclustering.identity.MemberId; import org.neo4j.causalclustering.identity.MemberId;
import org.neo4j.causalclustering.identity.StoreId;
import org.neo4j.causalclustering.messaging.Message; import org.neo4j.causalclustering.messaging.Message;
import org.neo4j.causalclustering.messaging.Outbound; import org.neo4j.causalclustering.messaging.Outbound;
import org.neo4j.kernel.availability.AvailabilityGuard;
import org.neo4j.kernel.availability.DatabaseAvailabilityGuard; import org.neo4j.kernel.availability.DatabaseAvailabilityGuard;
import org.neo4j.kernel.availability.UnavailableException; import org.neo4j.kernel.availability.UnavailableException;
import org.neo4j.kernel.impl.core.DatabasePanicEventGenerator;
import org.neo4j.kernel.impl.transaction.state.DataSourceManager;
import org.neo4j.kernel.internal.DatabaseHealth;
import org.neo4j.kernel.monitoring.Monitors; import org.neo4j.kernel.monitoring.Monitors;
import org.neo4j.logging.NullLog; import org.neo4j.logging.NullLog;
import org.neo4j.logging.NullLogProvider; import org.neo4j.logging.NullLogProvider;
Expand All @@ -54,24 +63,22 @@
import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.Assert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.graphdb.factory.GraphDatabaseSettings.DEFAULT_DATABASE_NAME; import static org.neo4j.graphdb.factory.GraphDatabaseSettings.DEFAULT_DATABASE_NAME;
import static org.neo4j.test.assertion.Assert.assertEventually; import static org.neo4j.test.assertion.Assert.assertEventually;


public class RaftReplicatorTest class RaftReplicatorTest
{ {
@Rule
public ExpectedException expectedException = ExpectedException.none();

private static final int DEFAULT_TIMEOUT_MS = 15_000; private static final int DEFAULT_TIMEOUT_MS = 15_000;


private LeaderLocator leaderLocator = mock( LeaderLocator.class ); private LeaderLocator leaderLocator = mock( LeaderLocator.class );
Expand All @@ -80,11 +87,21 @@ public class RaftReplicatorTest
private GlobalSession session = new GlobalSession( UUID.randomUUID(), myself ); private GlobalSession session = new GlobalSession( UUID.randomUUID(), myself );
private LocalSessionPool sessionPool = new LocalSessionPool( session ); private LocalSessionPool sessionPool = new LocalSessionPool( session );
private TimeoutStrategy noWaitTimeoutStrategy = new ConstantTimeTimeoutStrategy( 0, MILLISECONDS ); private TimeoutStrategy noWaitTimeoutStrategy = new ConstantTimeTimeoutStrategy( 0, MILLISECONDS );
private DatabaseAvailabilityGuard databaseAvailabilityGuard = private DatabaseAvailabilityGuard databaseAvailabilityGuard;
new DatabaseAvailabilityGuard( DEFAULT_DATABASE_NAME, Clocks.systemClock(), NullLog.getInstance() ); private DatabaseHealth databaseHealth;
private LocalDatabase localDatabase;

@BeforeEach
void setUp() throws IOException
{
databaseAvailabilityGuard = new DatabaseAvailabilityGuard( DEFAULT_DATABASE_NAME, Clocks.systemClock(), NullLog.getInstance() );
databaseHealth = new DatabaseHealth( mock( DatabasePanicEventGenerator.class ), NullLog.getInstance() );
localDatabase = StubLocalDatabase.create( () -> databaseHealth, databaseAvailabilityGuard );
localDatabase.start();
}


@Test @Test
public void shouldSendReplicatedContentToLeader() throws Exception void shouldSendReplicatedContentToLeader() throws Exception
{ {
// given // given
Monitors monitors = new Monitors(); Monitors monitors = new Monitors();
Expand Down Expand Up @@ -118,7 +135,7 @@ public void shouldSendReplicatedContentToLeader() throws Exception
} }


@Test @Test
public void shouldResendAfterTimeout() throws Exception void shouldResendAfterTimeout() throws Exception
{ {
// given // given
Monitors monitors = new Monitors(); Monitors monitors = new Monitors();
Expand Down Expand Up @@ -149,7 +166,7 @@ public void shouldResendAfterTimeout() throws Exception
} }


@Test @Test
public void shouldReleaseSessionWhenFinished() throws Exception void shouldReleaseSessionWhenFinished() throws Exception
{ {
// given // given
CapturingProgressTracker capturedProgress = new CapturingProgressTracker(); CapturingProgressTracker capturedProgress = new CapturingProgressTracker();
Expand Down Expand Up @@ -178,7 +195,7 @@ public void shouldReleaseSessionWhenFinished() throws Exception
} }


@Test @Test
public void stopReplicationOnShutdown() throws InterruptedException void stopReplicationOnShutdown() throws InterruptedException
{ {
// given // given
Monitors monitors = new Monitors(); Monitors monitors = new Monitors();
Expand Down Expand Up @@ -206,7 +223,7 @@ public void stopReplicationOnShutdown() throws InterruptedException
} }


@Test @Test
public void stopReplicationWhenUnavailable() throws InterruptedException void stopReplicationWhenUnavailable() throws InterruptedException
{ {
CapturingProgressTracker capturedProgress = new CapturingProgressTracker(); CapturingProgressTracker capturedProgress = new CapturingProgressTracker();
CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>(); CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>();
Expand All @@ -226,29 +243,41 @@ public void stopReplicationWhenUnavailable() throws InterruptedException
} }


@Test @Test
public void shouldFailIfNoLeaderIsAvailable() void stopReplicationWhenUnHealthy() throws InterruptedException
{
CapturingProgressTracker capturedProgress = new CapturingProgressTracker();
CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>();

RaftReplicator replicator = getReplicator( outbound, capturedProgress, new Monitors() );
replicator.onLeaderSwitch( leaderInfo );

ReplicatedInteger content = ReplicatedInteger.valueOf( 5 );
ReplicatingThread replicatingThread = replicatingThread( replicator, content, true );

// when
replicatingThread.start();

databaseHealth.panic( new IllegalStateException( "PANIC" ) );
replicatingThread.join();
Assertions.assertNotNull( replicatingThread.getReplicationException() );
}

@Test
void shouldFailIfNoLeaderIsAvailable()
{ {
// given // given
CapturingProgressTracker capturedProgress = new CapturingProgressTracker(); CapturingProgressTracker capturedProgress = new CapturingProgressTracker();
CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>(); CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>();


RaftReplicator replicator = getReplicator( outbound, capturedProgress, new Monitors() ); RaftReplicator replicator = getReplicator( outbound, capturedProgress, new Monitors() );
ReplicatedInteger content = ReplicatedInteger.valueOf( 5 );


// when // when
try assertThrows( ReplicationFailureException.class, () -> replicator.replicate( content, true ) );
{
ReplicatedInteger content = ReplicatedInteger.valueOf( 5 );
replicator.replicate( content, true );
fail( "should have thrown" );
}
catch ( ReplicationFailureException ignored )
{
// expected
}
} }


@Test @Test
public void shouldListenToLeaderUpdates() throws ReplicationFailureException void shouldListenToLeaderUpdates() throws ReplicationFailureException
{ {
OneProgressTracker oneProgressTracker = new OneProgressTracker(); OneProgressTracker oneProgressTracker = new OneProgressTracker();
oneProgressTracker.last.setReplicated(); oneProgressTracker.last.setReplicated();
Expand All @@ -271,7 +300,7 @@ public void shouldListenToLeaderUpdates() throws ReplicationFailureException
} }


@Test @Test
public void shouldSuccefulltSendIfLeaderIsLostAndFound() throws InterruptedException void shouldSuccessfullySendIfLeaderIsLostAndFound() throws InterruptedException
{ {
OneProgressTracker capturedProgress = new OneProgressTracker(); OneProgressTracker capturedProgress = new OneProgressTracker();
CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>(); CapturingOutbound<RaftMessages.RaftMessage> outbound = new CapturingOutbound<>();
Expand All @@ -297,7 +326,7 @@ public void shouldSuccefulltSendIfLeaderIsLostAndFound() throws InterruptedExcep
private RaftReplicator getReplicator( CapturingOutbound<RaftMessages.RaftMessage> outbound, ProgressTracker progressTracker, Monitors monitors ) private RaftReplicator getReplicator( CapturingOutbound<RaftMessages.RaftMessage> outbound, ProgressTracker progressTracker, Monitors monitors )
{ {
return new RaftReplicator( leaderLocator, myself, outbound, sessionPool, progressTracker, noWaitTimeoutStrategy, 10, databaseAvailabilityGuard, return new RaftReplicator( leaderLocator, myself, outbound, sessionPool, progressTracker, noWaitTimeoutStrategy, 10, databaseAvailabilityGuard,
NullLogProvider.getInstance(), monitors ); NullLogProvider.getInstance(), localDatabase, monitors );
} }


private ReplicatingThread replicatingThread( RaftReplicator replicator, ReplicatedInteger content, boolean trackResult ) private ReplicatingThread replicatingThread( RaftReplicator replicator, ReplicatedInteger content, boolean trackResult )
Expand Down Expand Up @@ -423,4 +452,22 @@ public void send( MemberId to, MESSAGE message, boolean block )
} }


} }

private static class StubLocalDatabase extends LocalDatabase
{
static LocalDatabase create( Supplier<DatabaseHealth> databaseHealthSupplier, AvailabilityGuard availabilityGuard ) throws IOException
{
StoreFiles storeFiles = mock( StoreFiles.class );
when( storeFiles.readStoreId( any() ) ).thenReturn( new StoreId( 1, 2, 3, 4 ) );

DataSourceManager dataSourceManager = mock( DataSourceManager.class );
return new StubLocalDatabase( storeFiles, dataSourceManager, databaseHealthSupplier, availabilityGuard );
}

StubLocalDatabase( StoreFiles storeFiles, DataSourceManager dataSourceManager, Supplier<DatabaseHealth> databaseHealthSupplier,
AvailabilityGuard availabilityGuard )
{
super( null, storeFiles, null, dataSourceManager, databaseHealthSupplier, availabilityGuard, NullLogProvider.getInstance() );
}
}
} }

0 comments on commit da3a460

Please sign in to comment.